From 19b4c5e065e7e5bbc8082dfc7dbd792bdac850fc Mon Sep 17 00:00:00 2001 From: Mark Hoemmen Date: Mon, 5 Aug 2024 12:28:13 -0600 Subject: [PATCH 01/53] Fix isnan namespace qualification in cutlass/functional.h (#1679) * Fix unrelated MSVC build warnings * Fix use of isnan in functional.h Correct namespace qualification of isnan in functional.h so that it invokes cutlass::isnan for half_t, instead of converting half_t to float and invoking std::isnan (on host, or ::isnan on device). --- include/cutlass/detail/helper_macros.hpp | 38 ++++++++++ .../sm90_epilogue_tma_warpspecialized.hpp | 5 +- include/cutlass/functional.h | 69 ++++++++--------- test/unit/core/functional.cu | 76 +++++++++++++++++++ 4 files changed, 152 insertions(+), 36 deletions(-) diff --git a/include/cutlass/detail/helper_macros.hpp b/include/cutlass/detail/helper_macros.hpp index 926ccc7308..4cd895f147 100644 --- a/include/cutlass/detail/helper_macros.hpp +++ b/include/cutlass/detail/helper_macros.hpp @@ -95,6 +95,44 @@ CUTLASS_HOST_DEVICE void __CUTLASS_UNUSED(T const &) #endif #endif +// CUTLASS_CMATH_NAMESPACE is the namespace where code can find +// functions like isnan and log. Such functions are in +// the std namespace in host code, but in the global namespace +// in device code. +// +// The intended use case for this macro is in "using" declarations +// for making argument-dependent lookup (ADL) work in generic code. +// For example, if T is cutlass::half_t, the following code will +// invoke cutlass::isnan(half_t). If T is float, it will invoke +// std::isnan on host and ::isnan on device. (CUTLASS's support +// for NVRTC prevents it from using things in the std namespace +// in device code.) Correct use of "using" declarations can help +// avoid unexpected implicit conversions, like from half_t to float. +// +// template +// bool foo(T x) { +// using CUTLASS_CMATH_NAMESPACE :: isnan; +// return isnan(x); +// } +// +// Without this macro, one would need to write the following. 
+// +// template +// bool foo(T x) { +// #if defined(__CUDA_ARCH__) +// using ::isnan; +// #else +// using std::isnan; +// #endif +// return isnan(x); +// } + +#if defined(__CUDA_ARCH__) +# define CUTLASS_CMATH_NAMESPACE +#else +# define CUTLASS_CMATH_NAMESPACE std +#endif + //////////////////////////////////////////////////////////////////////////////////////////////////// namespace cutlass { diff --git a/include/cutlass/epilogue/collective/sm90_epilogue_tma_warpspecialized.hpp b/include/cutlass/epilogue/collective/sm90_epilogue_tma_warpspecialized.hpp index 56b55292a8..4b2157b6c7 100644 --- a/include/cutlass/epilogue/collective/sm90_epilogue_tma_warpspecialized.hpp +++ b/include/cutlass/epilogue/collective/sm90_epilogue_tma_warpspecialized.hpp @@ -670,7 +670,8 @@ class CollectiveEpilogue< // We can delay issue of TMA store by one iteration to achieve better interleaving of non-TMA instructions // Sync requirements of smem reuse may preclude this optimization // Delayed stores cause delayed stage releases which causes deadlock when StagesC == StagesD - int epi_m_prev = 0, epi_n_prev = 0; + [[maybe_unused]] int epi_m_prev = 0; + [[maybe_unused]] int epi_n_prev = 0; static_assert(not (DelayTmaStore and ReuseSmemC and StagesC == StagesD), "This TMA epilogue configuration will deadlock"); // The TMA store sequence for one subtile iteration @@ -725,7 +726,7 @@ class CollectiveEpilogue< for (int epi_n = 0; epi_n < size<3>(gD_epi); ++epi_n) { CUTLASS_PRAGMA_UNROLL for (int epi_m = 0; epi_m < size<2>(gD_epi); ++epi_m) { - bool is_first_iteration = epi_m == 0 && epi_n == 0; + [[maybe_unused]] bool is_first_iteration = epi_m == 0 && epi_n == 0; bool is_last_iteration = epi_m == size<2>(gD_epi)-1 && epi_n == size<3>(gD_epi)-1; if (subtile_idx != -1 && (epi_n * static_cast(size<2>(gD_epi)) + epi_m) != subtile_idx) { diff --git a/include/cutlass/functional.h b/include/cutlass/functional.h index da946c7c5f..65e49d5290 100644 --- a/include/cutlass/functional.h +++ b/include/cutlass/functional.h @@ -369,11 +369,14 @@ template struct maximum { CUTLASS_HOST_DEVICE T operator()(T const &lhs, T const &rhs) const { -#if defined(__CUDA_ARCH__) - return lhs > rhs or ::isnan(lhs) ? lhs : rhs; -#else - return lhs > rhs or std::isnan(lhs) ? lhs : rhs; -#endif + using CUTLASS_CMATH_NAMESPACE :: isnan; + + // Call isnan unqualified, so argument-dependent lookup (ADL) + // will find overloads such as cutlass::isnan(half_t). + // Calling ::isnan or std::isnan directly would force + // implicit conversions to float of custom number types + // in the cutlass namespace (e.g., cutlass::half_t). + return lhs > rhs || isnan(lhs) ? lhs : rhs; } }; @@ -389,15 +392,14 @@ template <> struct maximum { CUTLASS_HOST_DEVICE float operator()(float const lhs, float const rhs) const { - float res; #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) + float res; asm volatile("max.NaN.f32 %0, %1, %2;\n" : "=f"(res) : "f"(lhs), "f"(rhs)); -#elif defined(__CUDA_ARCH__) - res = lhs > rhs or ::isnan(lhs) ? lhs : rhs; + return res; #else - res = lhs > rhs or std::isnan(lhs) ? lhs : rhs; + using CUTLASS_CMATH_NAMESPACE :: isnan; + return lhs > rhs || isnan(lhs) ? lhs : rhs; #endif - return res; } }; @@ -427,11 +429,9 @@ template struct minimum { CUTLASS_HOST_DEVICE T operator()(T const &lhs, T const &rhs) const { -#if defined(__CUDA_ARCH__) - return lhs < rhs or ::isnan(lhs) ? lhs : rhs; -#else - return lhs < rhs or std::isnan(lhs) ? lhs : rhs; -#endif + using CUTLASS_CMATH_NAMESPACE :: isnan; + + return lhs < rhs || isnan(lhs) ? 
lhs : rhs; } }; @@ -512,6 +512,8 @@ template struct guarded_multiply_add { CUTLASS_HOST_DEVICE C operator()(A const &a, B const &b, C const &c) const { + using CUTLASS_CMATH_NAMESPACE :: isnan; + if (isnan(a) || isnan(b)) { return C(0); } @@ -531,7 +533,10 @@ struct guarded_multiply_add { : "h"(*reinterpret_cast(&a)), "h"(*reinterpret_cast(&b)), "h"(*reinterpret_cast(&c))); return result; #else - if (isnan(a) || isnan(b)) { + // Namespace-qualifying isnan as cutlass::isnan saves the compiler + // the trouble of argument-dependent lookup. Calling std::isnan or + // ::isnan here would result in unwanted implicit conversion to float. + if (cutlass::isnan(a) || cutlass::isnan(b)) { return half_t(0); } return a * b + c; @@ -544,13 +549,9 @@ template struct guarded_multiply_add_relu0 { CUTLASS_HOST_DEVICE C operator()(A const &a, B const &b, C const &c) const { - if ( -#if defined(__CUDA_ARCH__) - ::isnan(a) || ::isnan(b) -#else - std::isnan(a) || std::isnan(b) -#endif - ) { + using CUTLASS_CMATH_NAMESPACE :: isnan; + + if (isnan(a) || isnan(b)) { return C(0); } maximum mx; @@ -569,13 +570,7 @@ struct guarded_multiply_add_relu0 { : "h"(*reinterpret_cast(&a)), "h"(*reinterpret_cast(&b)), "h"(*reinterpret_cast(&c))); return result; #else - if ( -#if defined(__CUDA_ARCH__) - ::isnan(a) || ::isnan(b) -#else - std::isnan(a) || std::isnan(b) -#endif - ) { + if (cutlass::isnan(a) || cutlass::isnan(b)) { return half_t(0); } maximum mx; @@ -782,6 +777,10 @@ struct atomic_add { #if defined(__CUDA_ARCH__) atomicAdd(ptr, data); +#else + CUTLASS_UNUSED(ptr); + CUTLASS_UNUSED(data); + CUTLASS_NOT_IMPLEMENTED(); #endif } }; @@ -793,8 +792,9 @@ struct atomic_add void operator()(double *ptr, const double &data) { #if !defined(__CUDA_ARCH__) - CUTLASS_UNUSED(ptr); - CUTLASS_UNUSED(data); + CUTLASS_UNUSED(ptr); + CUTLASS_UNUSED(data); + CUTLASS_NOT_IMPLEMENTED(); #elif (__CUDA_ARCH__ >= 600) atomicAdd(ptr, data); #else @@ -819,8 +819,9 @@ struct atomic_add void operator()(half2 *ptr, const half2 &data) { #if !defined(__CUDA_ARCH__) || (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 600)) - CUTLASS_UNUSED(ptr); - CUTLASS_UNUSED(data); + CUTLASS_UNUSED(ptr); + CUTLASS_UNUSED(data); + CUTLASS_NOT_IMPLEMENTED(); #else // Vector-2 atomic reduction requires .target sm_60 or higher uint32_t word = reinterpret_cast(data); diff --git a/test/unit/core/functional.cu b/test/unit/core/functional.cu index 174a089555..2b914b255f 100644 --- a/test/unit/core/functional.cu +++ b/test/unit/core/functional.cu @@ -491,4 +491,80 @@ TEST(Functional, multiply_add_quaternion_f32) { Functional_multiply_add_QuaternionT(); } +namespace cutlass_test { + +__global__ void +test_cutlass_maximum(cutlass::half_t const* in1, cutlass::half_t const* in2, cutlass::half_t* out) +{ + { + constexpr bool propagate_NaN = true; + cutlass::maximum op; + if (threadIdx.x == 0 && threadIdx.y == 0 && threadIdx.z == 0 + && blockIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0) { + *out = op(*in1, *in2); + } + } + { + constexpr bool propagate_NaN = false; + cutlass::maximum op; + if (threadIdx.x == 0 && threadIdx.y == 0 && threadIdx.z == 0 + && blockIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0) { + *out = op(*in1, *in2); + } + } +} + +} // cutlass_test + +// Test compilation on both host and device. 
+TEST(Functional, maximum_half_host_propagate_NaN) { + constexpr bool propagate_NaN = true; + cutlass::maximum op; + cutlass::half_t x(1.0f); + cutlass::half_t y(2.0f); + + auto result = op(x, y); + static_assert(std::is_same_v); + EXPECT_EQ(result, y); + result = op(y, x); + EXPECT_EQ(result, y); +} + +TEST(Functional, maximum_half_host_dont_propagate_NaN) { + constexpr bool propagate_NaN = false; + cutlass::maximum op; + cutlass::half_t x(1.0f); + cutlass::half_t y(2.0f); + + auto result = op(x, y); + static_assert(std::is_same_v); + EXPECT_EQ(result, y); + result = op(y, x); + EXPECT_EQ(result, y); +} + +TEST(Function, maximum_half_device) { + using Tensor = cutlass::HostTensor; + + Tensor in1({1, 1}); + Tensor in2({1, 1}); + Tensor out({1, 1}); + in1.host_data()[0] = cutlass::half_t(1.0f); + in2.host_data()[0] = cutlass::half_t(2.0f); + out.host_data()[0] = cutlass::half_t(0.0f); + + in1.sync_device(); + in2.sync_device(); + out.sync_device(); + + cutlass_test::test_cutlass_maximum<<< 1, 1 >>>( + in1.device_data(), + in2.device_data(), + out.device_data() + ); + out.sync_host(); + + EXPECT_EQ(out.host_data()[0], 2.0f); +} + ///////////////////////////////////////////////////////////////////////////////////////////////// From e22ba590cd8a7eebea8f53c81b5740d905021654 Mon Sep 17 00:00:00 2001 From: chenwei <15601910741@163.com> Date: Tue, 6 Aug 2024 23:15:18 +0800 Subject: [PATCH 02/53] support data type w2 used in cutlass_library (#1517) --- python/cutlass_library/library.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/python/cutlass_library/library.py b/python/cutlass_library/library.py index 710cad31ca..ab99bae83a 100644 --- a/python/cutlass_library/library.py +++ b/python/cutlass_library/library.py @@ -69,6 +69,7 @@ class GeneratorTarget(enum.Enum): class DataType(enum.Enum): void = enum_auto() # primarily used to disable C tensor for epilogues b1 = enum_auto() + u2 = enum_auto() u4 = enum_auto() u8 = enum_auto() u16 = enum_auto() @@ -119,6 +120,7 @@ class DataType(enum.Enum): DataTypeNames = { DataType.void: "void", DataType.b1: "b1", + DataType.u2: "u2", DataType.u4: "u4", DataType.u8: "u8", DataType.u16: "u16", @@ -156,6 +158,7 @@ class DataType(enum.Enum): DataTypeTag = { DataType.void: "void", DataType.b1: "cutlass::uint1b_t", + DataType.u2: "cutlass::uint2b_t", DataType.u4: "cutlass::uint4b_t", DataType.u8: "uint8_t", DataType.u16: "uint16_t", @@ -193,6 +196,7 @@ class DataType(enum.Enum): DataTypeSize = { DataType.void: 0, DataType.b1: 1, + DataType.u2: 2, DataType.u4: 4, DataType.u8: 8, DataType.u16: 16, From 2049c6c5a22bcc5c081a7c172eb4978f44602cb3 Mon Sep 17 00:00:00 2001 From: dePaul Miller Date: Thu, 8 Aug 2024 10:56:23 -0700 Subject: [PATCH 03/53] 5476 cutlass 3x gemm kernels (#1695) Co-authored-by: dePaul Miller <23461061+depaulmillz@users.noreply.github.com> --- python/cutlass_library/generator.py | 159 ++++++++++++++++++++-------- 1 file changed, 115 insertions(+), 44 deletions(-) diff --git a/python/cutlass_library/generator.py b/python/cutlass_library/generator.py index 8aa18b4b15..1f2eb86ecc 100644 --- a/python/cutlass_library/generator.py +++ b/python/cutlass_library/generator.py @@ -4960,47 +4960,68 @@ def GenerateSM90_TensorOp_16b_WGMMA_gemm(manifest, cuda_version): DataType.bf16, DataType.bf16, DataType.f32, OpcodeClass.TensorOp, MathOperation.multiply_add), + MathInstruction( + [64, 256, 16], + DataType.f16, DataType.f16, DataType.f16, + OpcodeClass.TensorOp, + MathOperation.multiply_add), + MathInstruction( + [64, 256, 16], + DataType.f16, 
DataType.f16, DataType.f32, + OpcodeClass.TensorOp, + MathOperation.multiply_add), + MathInstruction( + [64, 256, 16], + DataType.bf16, DataType.bf16, DataType.f32, + OpcodeClass.TensorOp, + MathOperation.multiply_add), ] min_cc = 90 max_cc = 90 for math_inst in math_instructions: - tile_descriptions_small = [ - # Not compatible with TmaWarpSpecializedCooperative - TileDescription([math_inst.instruction_shape[0], math_inst.instruction_shape[1], math_inst.instruction_shape[2]*4], - 0, [4, 1, 1], math_inst, min_cc, max_cc, [2,1,1]), - TileDescription([math_inst.instruction_shape[0], math_inst.instruction_shape[1], math_inst.instruction_shape[2]*4], - 0, [4, 1, 1], math_inst, min_cc, max_cc, [1,2,1]), - TileDescription([math_inst.instruction_shape[0], math_inst.instruction_shape[1], math_inst.instruction_shape[2]*4], - 0, [4, 1, 1], math_inst, min_cc, max_cc, [1,1,1]), - ] - tile_descriptions_medium = [ - TileDescription([math_inst.instruction_shape[0]*2, math_inst.instruction_shape[1], math_inst.instruction_shape[2]*4], - 0, [4, 1, 1], math_inst, min_cc, max_cc, [2,1,1]), - TileDescription([math_inst.instruction_shape[0]*2, math_inst.instruction_shape[1], math_inst.instruction_shape[2]*4], - 0, [4, 1, 1], math_inst, min_cc, max_cc, [1,2,1]), - TileDescription([math_inst.instruction_shape[0]*2, math_inst.instruction_shape[1], math_inst.instruction_shape[2]*4], - 0, [4, 1, 1], math_inst, min_cc, max_cc, [1,1,1]), - ] - tile_descriptions_large = [ - TileDescription([math_inst.instruction_shape[0]*4, math_inst.instruction_shape[1], math_inst.instruction_shape[2]*4], - 0, [4, 1, 1], math_inst, min_cc, max_cc, [2,1,1]), - TileDescription([math_inst.instruction_shape[0]*4, math_inst.instruction_shape[1], math_inst.instruction_shape[2]*4], - 0, [4, 1, 1], math_inst, min_cc, max_cc, [1,2,1]), - TileDescription([math_inst.instruction_shape[0]*4, math_inst.instruction_shape[1], math_inst.instruction_shape[2]*4], - 0, [4, 1, 1], math_inst, min_cc, max_cc, [1,1,1]), - TileDescription([math_inst.instruction_shape[0]*2, math_inst.instruction_shape[1]*2, math_inst.instruction_shape[2]*4], - 0, [4, 2, 1], math_inst, min_cc, max_cc, [2,1,1]), - TileDescription([math_inst.instruction_shape[0]*2, math_inst.instruction_shape[1]*2, math_inst.instruction_shape[2]*4], - 0, [4, 2, 1], math_inst, min_cc, max_cc, [1,2,1]), - TileDescription([math_inst.instruction_shape[0]*2, math_inst.instruction_shape[1]*2, math_inst.instruction_shape[2]*4], - 0, [4, 2, 1], math_inst, min_cc, max_cc, [1,1,1]), - # 128x256x128 - TileDescription([math_inst.instruction_shape[0]*2, math_inst.instruction_shape[1]*2, math_inst.instruction_shape[2]*4], - 0, [4, 2, 1], math_inst, min_cc, max_cc, [1,1,1]), - ] - tile_descriptions = tile_descriptions_medium + tile_descriptions_large + tile_descriptions = [] + tile_descriptions_small = [] + tile_descriptions_medium = [] + tile_descriptions_large = [] + + if math_inst.instruction_shape[1] == 128: + tile_descriptions_small = [ + # Not compatible with TmaWarpSpecializedCooperative + TileDescription([math_inst.instruction_shape[0], math_inst.instruction_shape[1], math_inst.instruction_shape[2]*4], + 0, [4, 1, 1], math_inst, min_cc, max_cc, [2,1,1]), + TileDescription([math_inst.instruction_shape[0], math_inst.instruction_shape[1], math_inst.instruction_shape[2]*4], + 0, [4, 1, 1], math_inst, min_cc, max_cc, [1,2,1]), + TileDescription([math_inst.instruction_shape[0], math_inst.instruction_shape[1], math_inst.instruction_shape[2]*4], + 0, [4, 1, 1], math_inst, min_cc, max_cc, [1,1,1]), + ] + 
tile_descriptions_medium = [ + TileDescription([math_inst.instruction_shape[0]*2, math_inst.instruction_shape[1], math_inst.instruction_shape[2]*4], + 0, [4, 1, 1], math_inst, min_cc, max_cc, [2,1,1]), + TileDescription([math_inst.instruction_shape[0]*2, math_inst.instruction_shape[1], math_inst.instruction_shape[2]*4], + 0, [4, 1, 1], math_inst, min_cc, max_cc, [1,2,1]), + TileDescription([math_inst.instruction_shape[0]*2, math_inst.instruction_shape[1], math_inst.instruction_shape[2]*4], + 0, [4, 1, 1], math_inst, min_cc, max_cc, [1,2,1]), + ] + tile_descriptions_large = [ + TileDescription([math_inst.instruction_shape[0]*4, math_inst.instruction_shape[1], math_inst.instruction_shape[2]*4], + 0, [4, 1, 1], math_inst, min_cc, max_cc, [2,1,1]), + TileDescription([math_inst.instruction_shape[0]*4, math_inst.instruction_shape[1], math_inst.instruction_shape[2]*4], + 0, [4, 1, 1], math_inst, min_cc, max_cc, [1,2,1]), + TileDescription([math_inst.instruction_shape[0]*4, math_inst.instruction_shape[1], math_inst.instruction_shape[2]*4], + 0, [4, 1, 1], math_inst, min_cc, max_cc, [1,1,1]), + ] + tile_descriptions = tile_descriptions_medium + tile_descriptions_large + else: + tile_descriptions = [ + TileDescription([math_inst.instruction_shape[0]*2, math_inst.instruction_shape[1], math_inst.instruction_shape[2]*4], + 0, [4, 2, 1], math_inst, min_cc, max_cc, [2,1,1]), + TileDescription([math_inst.instruction_shape[0]*2, math_inst.instruction_shape[1], math_inst.instruction_shape[2]*4], + 0, [4, 2, 1], math_inst, min_cc, max_cc, [1,2,1]), + TileDescription([math_inst.instruction_shape[0]*2, math_inst.instruction_shape[1], math_inst.instruction_shape[2]*4], + 0, [4, 2, 1], math_inst, min_cc, max_cc, [1,1,1]), + ] data_type = { "a_type" : math_inst.element_a, @@ -5043,7 +5064,7 @@ def GenerateSM90_TensorOp_16b_WGMMA_gemm(manifest, cuda_version): # persistent kernels with TMA epilogues if CudaToolkitVersionSatisfies(cuda_version, 12, 1): # not enough smem for 256x128 f32 out with C allocation - if data_type["d_type"] == DataType.f32: + if data_type["d_type"] == DataType.f32 and len(tile_descriptions_medium) > 0: CreateGemmUniversal3xOperator(manifest, layouts, tile_descriptions_medium, data_type, [[KernelScheduleType.TmaWarpSpecializedPingpong, EpilogueScheduleType.TmaWarpSpecialized], [KernelScheduleType.TmaWarpSpecializedCooperative, EpilogueScheduleType.TmaWarpSpecializedCooperative]]) @@ -5490,20 +5511,30 @@ def GenerateSM90_TensorOp_int8_WGMMA_gemm(manifest, cuda_version): DataType.u8, DataType.u8, DataType.s32, OpcodeClass.TensorOp, MathOperation.multiply_add), + MathInstruction( + [64, 256, 32], + DataType.s8, DataType.s8, DataType.s32, + OpcodeClass.TensorOp, + MathOperation.multiply_add), + MathInstruction( + [64, 256, 32], + DataType.u8, DataType.u8, DataType.s32, + OpcodeClass.TensorOp, + MathOperation.multiply_add), ] min_cc = 90 max_cc = 90 for math_inst in math_instructions: - # 64x128x128 + # 64x128x128 or 64x256x128 tile_descriptions_small = [ TileDescription([math_inst.instruction_shape[0], math_inst.instruction_shape[1], math_inst.instruction_shape[2]*4], 0, [4, 1, 1], math_inst, min_cc, max_cc, [2,1,1]), TileDescription([math_inst.instruction_shape[0], math_inst.instruction_shape[1], math_inst.instruction_shape[2]*4], 0, [4, 1, 1], math_inst, min_cc, max_cc, [1,2,1]), ] - # 128x128x128 + # 128x128x128 or 128x256x128 tile_descriptions_medium = [ TileDescription([math_inst.instruction_shape[0]*2, math_inst.instruction_shape[1], math_inst.instruction_shape[2]*4], 0, [4, 1, 1], 
math_inst, min_cc, max_cc, [2,1,1]), @@ -5670,6 +5701,27 @@ def GenerateSM90_TensorOp_fp8_WGMMA_gemm(manifest, cuda_version): DataType.e5m2, DataType.e5m2, DataType.f32, OpcodeClass.TensorOp, MathOperation.multiply_add), + # inst 64x256x32 + MathInstruction( + [64, 256, 32], + DataType.e4m3, DataType.e4m3, DataType.f32, + OpcodeClass.TensorOp, + MathOperation.multiply_add), + MathInstruction( + [64, 256, 32], + DataType.e4m3, DataType.e5m2, DataType.f32, + OpcodeClass.TensorOp, + MathOperation.multiply_add), + MathInstruction( + [64, 256, 32], + DataType.e5m2, DataType.e4m3, DataType.f32, + OpcodeClass.TensorOp, + MathOperation.multiply_add), + MathInstruction( + [64, 256, 32], + DataType.e5m2, DataType.e5m2, DataType.f32, + OpcodeClass.TensorOp, + MathOperation.multiply_add), ] min_cc = 90 @@ -5788,9 +5840,6 @@ def GenerateSM90_TensorOp_fp8_WGMMA_gemm(manifest, cuda_version): 0, [4, 1, 1], math_inst, min_cc, max_cc, [2,1,1]), TileDescription([math_inst.instruction_shape[0]*4, math_inst.instruction_shape[1], math_inst.instruction_shape[2]*4], 0, [4, 1, 1], math_inst, min_cc, max_cc, [1,1,1]), - # 128x256x128 - TileDescription([math_inst.instruction_shape[0]*2, math_inst.instruction_shape[1]*2, math_inst.instruction_shape[2]*4], - 0, [4, 1, 1], math_inst, min_cc, max_cc, [1,1,1]), ] tile_descriptions = [ # 128x128x128 @@ -5801,6 +5850,27 @@ def GenerateSM90_TensorOp_fp8_WGMMA_gemm(manifest, cuda_version): TileDescription([math_inst.instruction_shape[0]*2, math_inst.instruction_shape[1], math_inst.instruction_shape[2]*4], 0, [4, 1, 1], math_inst, min_cc, max_cc, [1,1,1]), ] + elif math_inst.instruction_shape[1] == 256: + tile_descriptions_small = [ + # 64x256x128 + TileDescription([math_inst.instruction_shape[0], math_inst.instruction_shape[1], math_inst.instruction_shape[2]*4], + 0, [4, 1, 1], math_inst, min_cc, max_cc, [1,2,1]), + TileDescription([math_inst.instruction_shape[0], math_inst.instruction_shape[1], math_inst.instruction_shape[2]*4], + 0, [4, 1, 1], math_inst, min_cc, max_cc, [2,1,1]), + TileDescription([math_inst.instruction_shape[0], math_inst.instruction_shape[1], math_inst.instruction_shape[2]*4], + 0, [4, 1, 1], math_inst, min_cc, max_cc, [1,1,1]), + ] + tile_descriptions_large = [] + tile_descriptions = [ + # 128x256x128 + TileDescription([math_inst.instruction_shape[0]*2, math_inst.instruction_shape[1], math_inst.instruction_shape[2]*4], + 0, [4, 1, 1], math_inst, min_cc, max_cc, [1,2,1]), + TileDescription([math_inst.instruction_shape[0]*2, math_inst.instruction_shape[1], math_inst.instruction_shape[2]*4], + 0, [4, 1, 1], math_inst, min_cc, max_cc, [2,1,1]), + TileDescription([math_inst.instruction_shape[0]*2, math_inst.instruction_shape[1], math_inst.instruction_shape[2]*4], + 0, [4, 1, 1], math_inst, min_cc, max_cc, [1,1,1]), + ] + else: assert False, "math inst is not supported" @@ -5842,9 +5912,10 @@ def GenerateSM90_TensorOp_fp8_WGMMA_gemm(manifest, cuda_version): [KernelScheduleType.TmaWarpSpecializedPingpongFP8FastAccum, EpilogueScheduleType.NoSmemWarpSpecialized]]) # Large tiles - CreateGemmUniversal3xOperator(manifest, layouts, tile_descriptions_large, data_types_large_tile, - [[KernelScheduleType.TmaWarpSpecializedCooperative, EpilogueScheduleType.TmaWarpSpecializedCooperative], - [KernelScheduleType.TmaWarpSpecializedCooperativeFP8FastAccum, EpilogueScheduleType.TmaWarpSpecializedCooperative]]) + if len(tile_descriptions_large) > 0: + CreateGemmUniversal3xOperator(manifest, layouts, tile_descriptions_large, data_types_large_tile, + 
[[KernelScheduleType.TmaWarpSpecializedCooperative, EpilogueScheduleType.TmaWarpSpecializedCooperative], + [KernelScheduleType.TmaWarpSpecializedCooperativeFP8FastAccum, EpilogueScheduleType.TmaWarpSpecializedCooperative]]) # Add stream-K variants (with and without TMA epilogues) CreateGemmUniversal3xOperator(manifest, layouts, tile_descriptions, data_type, stream_k_schedules, tile_schedulers=[TileSchedulerType.StreamK]) From 7192f4ab230bb721fa8d4d3df33886dbe86cdc59 Mon Sep 17 00:00:00 2001 From: Tri Dao Date: Thu, 8 Aug 2024 11:00:24 -0700 Subject: [PATCH 04/53] Add CLayout_64x208 (#1680) Without this I get compilation error when the extended shapes are enabled --- include/cute/atom/mma_traits_sm90_gmma.hpp | 3 +++ 1 file changed, 3 insertions(+) diff --git a/include/cute/atom/mma_traits_sm90_gmma.hpp b/include/cute/atom/mma_traits_sm90_gmma.hpp index 3a4fdfa1a5..e59bbeefc2 100644 --- a/include/cute/atom/mma_traits_sm90_gmma.hpp +++ b/include/cute/atom/mma_traits_sm90_gmma.hpp @@ -422,6 +422,9 @@ using CLayout_64x176 = Layout,Shape < _2,_2, Int<22> using CLayout_64x192 = Layout,Shape < _2,_2, _24>>, Stride,Stride<_64,_8,_512>>>; +using CLayout_64x208 = Layout,Shape < _2,_2, Int<26>>>, + Stride,Stride<_64,_8,_512>>>; + using CLayout_64x224 = Layout,Shape < _2,_2, Int<28>>>, Stride,Stride<_64,_8,_512>>>; From 4e5a8f6853817e6595189e712a8018e1b71e4380 Mon Sep 17 00:00:00 2001 From: dePaul Miller Date: Mon, 12 Aug 2024 15:55:55 -0700 Subject: [PATCH 05/53] 3.5.1 plots and updated readme (#1708) Co-authored-by: dePaul Miller <23461061+depaulmillz@users.noreply.github.com> --- README.md | 13 ++++++------- ...cutlass-3.5.1-gemm-peak-performance-fp8.png | Bin 0 -> 121977 bytes .../cutlass-3.5.1-gemm-peak-performance.png | Bin 0 -> 112581 bytes 3 files changed, 6 insertions(+), 7 deletions(-) create mode 100644 media/images/cutlass-3.5.1-gemm-peak-performance-fp8.png create mode 100644 media/images/cutlass-3.5.1-gemm-peak-performance.png diff --git a/README.md b/README.md index 9ac15f4165..3c5700a027 100644 --- a/README.md +++ b/README.md @@ -101,16 +101,15 @@ Starting from CUTLASS 3.0, CUTLASS removed support for the following: # Performance -
[image: previous CUTLASS performance plot]
+[image: media/images/cutlass-3.5.1-gemm-peak-performance.png]
+[image: media/images/cutlass-3.5.1-gemm-peak-performance-fp8.png]
CUTLASS primitives are very efficient. When used to construct device-wide GEMM kernels, they exhibit peak performance comparable to cuBLAS for scalar GEMM -computations. The above figure shows CUTLASS performance relative to cuBLAS -for large matrix dimensions on an [NVIDIA H100](https://www.nvidia.com/en-us/data-center/h100/) (NVIDIA Hopper architecture), -an [NVIDIA L40](https://www.nvidia.com/en-us/data-center/l40/) (NVIDIA Ada architecture), -an [NVIDIA A100](https://www.nvidia.com/en-us/data-center/a100/) (NVIDIA Ampere architecture), -and an [NVIDIA A40](https://www.nvidia.com/en-us/data-center/a40/) (NVIDIA Ampere architecture). -CUTLASS 3.0 was compiled with the [CUDA 12.0 Toolkit](https://developer.nvidia.com/cuda-downloads). +computations. The above figure shows the continual CUTLASS performance improvements +on an [NVIDIA H100](https://www.nvidia.com/en-us/data-center/h100/) (NVIDIA Hopper architecture) since +CUTLASS 3.1. +CUTLASS 3.5.1 was compiled with the [CUDA 12.5u1 Toolkit](https://developer.nvidia.com/cuda-downloads). Tensor Core operations are implemented using CUDA's [mma](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#warp-level-matrix-instructions-mma) and [wgmma](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#asynchronous-warpgroup-level-matrix-instructions) instructions. diff --git a/media/images/cutlass-3.5.1-gemm-peak-performance-fp8.png b/media/images/cutlass-3.5.1-gemm-peak-performance-fp8.png new file mode 100644 index 0000000000000000000000000000000000000000..bca203c0cb0d376ba59ca332e93867eebe58ec09 GIT binary patch literal 121977 zcmeFYXH-*p*FJ0+b*zjsibzvIP*D&OkP@PzB1PheigcAGU21@|aZnLy3L-VhC{3hE zFCmFaiAa-9NJ6BB5J-SPLP+}`oO|x)`SSnxz8~K8W-Z7%C$Ms^-@bO=JNAZ+`R*Nu zc5K+NVYkJ#%QrV{*o)e*VY~jXn}Iv3{WH$M+s2Ta=9e~w!cfQSTl`M@Prxggk%#{6&6vLPEl;5=3%xa%%bG^z?Kz8jY<_eE05M0r0BJ zDlRTACB7?d{=2-qyr!n6uK8pAw-OuLp z0|NuVvjNWoJQ$rp{_*3-#KgqZ)D(lkn3=RC?eEYZ!pJE%c_1?X`OmX>`D1!nFNmsM78nH_VDVU#_PCbusVQ7pZ4e-zPuO?TO=?J@(-W>rklia} z(8B_CRz%pn}!(Oq1PRsAj?ybZpc9Yd0^x0>%nBx$e%vVC+#1_%Rx zQz3d*w?gTeZ4c_P?W%Zw!sGQ?TxEL^wk7<&I6VTFY!!I^l*qo zA{DM*W(o&Au+zHK&l(0bVeQ%>q9yUbZMv7(Z`u;~NcQ5$`FFfbjD7S~l4yu={qCC+ z-1+BFB^HI%-_lHfWhvd{I96ZtrC@khnTZ$|X%yyfiVYFgFBu>P7qq{l9v28D>d5q( z(fvlSZ^Z8?i^2P<%0`oH3XZ!f=gZ(2+3uw`J~uxQXAm@$$jS0D6dr=&a}VQ*lt=k7!DTwz%t?!JAaWamxKuk@EcA9zsFaV*8}nL@#80LIcktD8YZ z>Pzyw4@vf~X)oLYLt`AzX%U(yt_ZfRY1bn5+nLnRn;2+X{9sXF`3064m2z-IV8MIo zHq&WFRdfS6FhZyS@ibcG%-l&A#q;_`KOc1{(2bdxW;@$26jv{g`=2ssVJsU;JFXC? zt#PS1#A9Pzdj*z|*cLLk1ctUql`&u6Aamhk!VeOX!@-A)kBrfvk zz(41@yne2n^>eL&=2bKwN{ejlrYq$b!Y9d%=}+|S7QQQ z*6gT?$DUm)*c@4bZUZ4(mlR-03y`m)1!`;l`;tl_qY5assq)t>I^NjU!giM(yk3c z=;8(JcCzsGWoZp(b?gbgE?Osksx@k+W>#{c$0TH=BscclkEzeUWW5F`*eYs}riUM- zlTE|E?Gh;5WV+tdqy{yVEHL$(G}ZzzmjGFM(UFoXSUr*Lvg?l%Jhsf~9VVL#MG~v< z=iHTA3Vq0R)URtwh3CsjylUjM3Oo+F$yKOGWJo_&CL#1qx*`GPIMySh6ly9Odm<~> zM-0EZ1V;6kKu1dAj~{+RPesVu+H@3ORCPmE*51l~1QK}75=FZUm$g6cXB(04th#57 zr63F-f7f&}LIT0kby}SUX!VN>EQxTc^g;Pwujwekp+Mqd__9_Pf3xwSHe^>2EVw&# zvfAtY@{jB-vK4%b*iq=EyR&{$3~SILjwwu9OxkCr!tnEi;8OqC(mW35Y+OE(q30gj z0ejOv#dX^Qdn0Lg{YM*>G$=c;e!B1q&UD}H8Ub7~&v0F~ZvtPHZowQzXO_iC+Ere% z&&HpgQz15khuta*7I(oUi-e?ME+@k^XdzhU=L49YVN47T-|r|ml@HVQ)>#hf@{m+H z%G+I-18=);>ACx$3YVtY7#fepFBz_BqpqA3@AP{LMf+*OD>@h22w9Lj4*V+!DPXp*EYatk zz6VeCZ;pAv}0U;!oA- zZrE>1_FQ$RcWQj_;Ea;gFQB)&e z1RpvPeKmiiCLLK$?H*n9v`>oH{irGIxWcW>b=JI|yILy)#E4i}cv+0PPWnrGv3d(! 
z^GkOVR*hM2T^uZ}YoavPs=(Z}!c7m{yA2?8>_3DKy%Fd&WI)?uXeHxLLL@Zd)Z6w_ zdeV-r8uLXd&}vYBHamc?sG*3)ryQ71E&Vn~F|n6n4!&0MKADt8sC*@?s8`+boMD60 zVWnee$phy|?t>f6n@zMq=0m6ctG~>m)ZcZQdyOzyN?F}EFJ+^oMMtoWEIAVp5Lr7hGTZO zc6DuaP-{gr!-<{62mqtf_S*9l z+pI8atBgrR0;Ow8z3dK;mz*nQz9yt#(6yFC&|k>P-C-XNBX_@F|TFOLm%_iJeXPP4T%-p)d$DE7{2 zVd9-Z0ZexTH(292HQb3dm&;2qaH&>CLxVWkpaE8x8(M5*^u&Peqd2P9wpFwa(fxwa z$Z;hG)m1<7+i2?cIQBKWWBiJ&znUs0Fo->_H++XC5oA9O3OklGSKk-VVU*vVyLl$5 zs~DtCGNE}|pS+u&_dj;Yc;kusmsU%cv-*m5XOvNs z(0#L*vD+){EMfKB!bY2)KJ~t)$l7sC*8iEVEip?Uo1o!usuB1Q^HJRKkKE2b6(%ve znOB4T^k^P4`f9yR&w4Dbd8Mb}6d}3sj>PeQLp` z9!+l*$hXXLtZ8YTlFu$p!h(M&ijb(qU|!psske%}-h^rwpik-?`!`amb4Um^t0%YS zInJyMhK%T&qQ)`WEsThV^9YW$lkJ2K+a3xeP$1Hs)mvF(&66C@LtS`3(`^ZnX0aOn zU|ArGx**+RGs(84Uo4z|pMcmMyZqfVbYE`^|2c?0kv2-k&e2M2I7)VD)=o4e^UMVM zr7$lvNh30$t{^C4NUg-DRk>TB09aLgedDlYlKvhUeWTa-*zIwi$mYfufm*OGDGIS% zO1BflB9%G4Y{ffKTG&qEudsi!C-0)_SG2!g>Rk?no~P}1prxS4OiuCONv=_hRd-H6SQO{Ic%iU}qD{NPtiR-G zG;Z)OHf;G}w#g4|Di>0Efnu0*N`cW>@cRfBpRqlThm*cR7kob$`CBCfd9kAnS0T&y zNb5EHz(2&^0)mbt*psR>%dgiIc-_HT=t$9aFD9Pdn$t@^+a^3zi*CPelb5%QFV@V= zeMos7ypBw$jqm0Ya6cv^;-l7KjTzQ{(ID?qSzI?MAP=sgzgftKc?Pq#>4q*)9z&}L zwEKHKZCdcSrZx^TduMiyGVdRW89wU)C`mOA0K>;;c~<6?{Ej{8rrC2(nn4NCICXNb zaax$tf*f0*A*`a|2Xo&TpE>Ui;e}tnTSf0ax9yO5w`6torHog)$undx!OzY87}eO* z_w<{GAhY$PYOcbaV#HRNgk++9s!_E@^-wBuwHXLz*dEJsp9}U|RLJh*pW8E^dCJ4g zhqRk0*$Hql4W_cWW2pBYbkFDcgB_#JfZH1?)G*(LR#`F@?lOl@*Pr{EBUv2Y`JRn4 zV(1poo`-cz-x{`Ah`Mv6{|Jq91!dLKD57cae+eYSxv%M*+d&OzO z$#16|<;YaR&-52c-l2gB(G*p6F8uK?umZoIz&wn1?FWFRG(b>y(gw}-dF(v}!mUNy z`c#HN<5b4FmO(9)vrD8EIKYNXjhFUYtZ9=Evp3(qbglc~!9xq;Mwo;wSP&|XfK_9}&uYAE5|e&l)or)E(gay6{X+ zekY^1c}xvRfzJ={Lh|R&^aG;_Hm8ki8k8>(*3HEh2U?%}0V>B&T!awy8j})}U3ox( z;(v2canf=~b=_j6@XLYr6A(jBRG^;*+?1aI& zf)Qj_wbHXFAkr)ZZ6LDhHSZbxZXR$d5g`993CM z2ZU$ROmvF;@_c$O_@KwKDzA9=(Ea3hyD*J1)M{@dpg}dWX-&y z*W&GYe+6(2P(@q~fwofT?G|Su<`C;5Q|hkWZyKcw{{V>0%S|->E@*XkN%pL6roZv7 zm0m!#tWYI`#+PBxmk$$&Wwz0@^cS}j4HSqFB^ygM;4 zD(WV6;5qD&Tewd}q3G&<4zpo*qcrr@V>DitE$}i%2@%gx73QwIJ?VhV{EoJ6htZww zx9A8OcDeCAwkg?O0E$dXPNd@JlrGmtBON0^+g&sTsi%Nno=vF@see^8@CxbEX>Zgv zbd(&LZ>2Qa%4p|k-{6>fxi;r@!cdn?cqYCwj$?@o$_|x!`7S&D_I10_n%l`3GpGvw z&^_ah)%*BXx=ACAw|h{R!l5YHY-=LfTx*Ev{+^WRZB?LM%dNzU`qhPI%tI~CzGCUU zb%QTszmk0zZT%JN0zS<4i7$Ly!3uMVCs%(>^jS&z0D<0p#@jP7)+);vRg!<@D+E36 zCCn+Ex!+@2FiUi=RJ6gEtM6V)e?ktrwQ;f08~RA8?~s;#QlIDGU|qUbO4TVBZ1B4L z)B_p8=J_{oHlA)LOorHWs{=t6ow#LmQ5EYPl41;rBYm&r0qx-kAb}6c8tYph@=)k2 zRYWfAuC`f`@2vZWkZpMyt;MK#{`{T-uSX=o5L3_`lQaWyG*W^Q*U?pI0O+(_yy$sS zm7fuaR$ozFpr+F4Y~_*{xI=C}f5RCMI+z^hPMNu(0tP|`5A@g;t3XmetvItiX#ohb z@x3qaKoubuw#wW{dZ|wNT6ABnb-tbbA9xRyxWA7qtJ-ASJLWvV1cbg(_*28_u`Li&Cv<-4<{2N8rN~o z{h%C);g<<3pr4*xfw}7~^lM(o?Un7>y2qcbnY*>$uIV2&~e3+F+YTS8<7s z4g`{7A_$QF7v*x?Becse2COU~Ux&;SO-9wbiDYGvj`HU9@^swCQF#&Dp*hPg)2l)faaJPx-E z|HwRIxc%XbO1wP~ODfRFI@U3V3m~#ky{iKk$SbM`wAXx7DKddIw;iHBl?TSG9HFTw zP>bS83J?kxvSy!Vhv@B|fuogy8r02qNoh7^jUoGcD4wQ0tHM8VRd0ttUy-YkiKjAa z|LH#P%G?~(wEZC@L#@iI2lsNZ_D}UZYo))O(2X=sGuBT{hH{U{Gm>39Jb}pBF;z#L zRwbf#tSmK0oLT|-syL2O*F{U-1(61@@Bfr7ylb9e0kl)ifa>CkM-Z=^!#K_Y6{Ega z7;|ZLlYcUJQ5Tq_m%Om=Nq}@G&dQMwxfk5l5FTc(Y#s3peG*e;mMr-xVpvWP-NC5N z0CtM)yK63UvO=~=2%TbQ;qm@X+#>Bc!nN@Eq|ZsJ%7G@{-Qe9}(!O5D!ZC7Ts5I9o1-W}B zJUTyZAPW^jQhu5zH`^2P3EZRd^H^7PD1ctL(E3CTB;(O{BJzRBn>9>zQpuy~WgtBY zcvgTk;8zNqGCm66t@8};|DYdG#MiA>RNAH}InqM6R^W5}z*@$sk-9tdXvD8DHAMbt zG&&7l@BBEXSe(q3@krZAezOJ2+?YM5V(m;rk}D?%u7y{fgdJ!5&XE)Qwm@7K_Y?(pdSNKU&AN6YPS6_dhob41?bF-`~NuT5I!qh+m!Q@1Dq&2KZz@I zM&@GhGl%e8DNQI+e4cvlTBP{QvPpfDuwvUkEysC#QdaS6ml0FWs|>v_pVJEpYM1To z>SgmfCg+{l!?!GJL(B?Pu2WmRx+BlP6WlGwJJS|p7w^u(23=~t#L5HHUX)!Bae(qY 
zZvEN&TmhWYEsGlb1dP8a`mRhwE}oG7(O?m?IS?f=-W@x?VVJ`Ve6-}2X?mRl;W`^B zM~?K8lV&}g1e`fELuf{J8Ck=-dUdCWquYSQ17^-twhiB)Pj`G_tT5s8H`;s2MSLy% zV!$ZL*e2=<`Ns6rx4beGdPu#=c5~~dgaV*z0JDjnFAk*#BM!oVu7TyUbEiUM-RxWg z&E4#`rHy0E;ibt7?jakv6|mO}M^sF)SAu^;8YQ02G)O#k7^6G?Y;ck1TM%yn{_3Xd z`y^VY`@`G4IB|hxMj_Bk=qY@i`pj3rG`Fo?6biNZW%SS;-!pqcs(1yEh1e8VOGkoDMN~5Xn&k5& z5bnE9|LF7F#Y3BxcQ`% zhwY|o92@0mqi>$kly1coEaC*AcD=7f`}ofMgv^7%?i&}gz0(N&!gTnL*-j4zFx6CZ zrK22}OS;s-4cQWPlUPnu%+~VA4`95pd1WcsYB;%Bk{>X^=euSyqR{@zoRmvQ;kmwX z9iBZ5sDQFFy?;&*O>(9;>SM<~O>-fBV2Id$vjR}lFJ8U9qYS7P2cUROydhMNKrv5aD>J?%#u-OBEy7X;Tc_E zQ6sVr&z{plPo-Ub0aH*FY70VtGfM{5v)@(?PO(W?70BH4V~5;$BH=k)Qd{Zpk1M7X zL0y4&_$?(L4PXj?9=vLuDws<2buIOfnzDV5oo?_y=R$Ng?Tl%k`yQw{c)A%~ z(m0UjGW(9s66}2LWsNh7ty&De{i+J6ux+$`V9FM&`^xjHFv(|aYeO_F^9Y?&Hbd@! zTSZAd_NS*|4U(w&ZFPug@5O_{VzJ*-AW$!mN-k%OA$U(3o2GT==y1u{Ji~P2KQ?GT z(G=$7mz1;2z7O>%7-friC5ie@0jP>TrB-hsp5@P`{P!dfVL}fQd|POuLJK~jT%6XW z_(a?bs-H9Z1E!Y3ntOib5Q#02IuTMPOJuMYF-om2$2_RRZYlU6Sdv-~wB@YOU22+f z_Q46ZlTo8?#2Y24QWPd{FOKaN!q~p0nVUze6VJ4an-ul~)iwa^DEii0diaNS+jSE| zExdVxqg!jFTJr{I^N^NZ-?+HCI6E`!9$4)B_%B*vw$D#0^mC?<=wp3hU;?;8Y_j6D z+i!;WseF%Aq@gPl?{jSu{UjP(N8U4$n`gZF-kOKdtgR=zv{Nf;Nha|w4H;v6g@(?m z!|60U>(Z_m)>bIy;tKAKtQ<8|tWkaAt%zd~cioHnGir6m4zeE6PIte}*HMdDg70O& z%8(07R5H6)d30x&tj|p3o2!j^Y_dvIqSj%I_WacYzm@eJRzllv&TnpHe1~K2u7$4q zQRs(TvxXZB8LrpKtp!Kgf{E1zXPBCh7o-`lc`r@s#I2vxNPrKKDHk|VZ~~DkrQ7yD zrjzv}$+jU7EB&&jSEV>;;DvjbgHLFDp@!~G7Y>v@`x2>{(f7tvmIwzYO9kiDZ%%7a-fv)UYV1Jb& zklsCYA|f(>#2a@XHATHKZ^_k7$xTJB8O2zUyqk;kmjOq!=d{H1nk2Syl^ z{{yjz6FPp2*N0{Ak=YeZGh_#c+~XfLq@cVDllw zA#IMnEdtLywRph|4uqThx6TQ2X!LvRHon+}OFKMyc2pBV*2>q`xD*$0C zRNnQ(rUZf^68p`*>*gv3Z=T!5y;uN@jc{gCFhxSzJa~FQ;Eb(W{#5_H{V(^wrOBKV z#b~MZp5$H2u2+FR91h{-Po#mWeos254lEIH_%du%eI+d7Ff|UkZOy9uuG_4Fy2 z=medhv;dfI+>d_(pkAk$pz@HQ=bB&W0y(yO=^Aj6G^+R31mdcGS_33 z5H6uDX6PM0K=zX5iZ7Cfkpf8SS*w(q5Dlw5w72~R!w!Kv zh>ZbDT3SXCg;!4v9~|L=L))TRq@=~6UNXT_6{tIArabyN9H0#5E9K4AM|o7-jN{Fg zYQlED}(Lvv&Iy&UP=X zvQ?7tQ!x)*_eg8FlU8Nn3depzjv3w&m>sSO|IoPrIHN`%GFC;X(ei@i|R<3s0^&UbeB}bhb%i~w7FfYw6OE zeF3gVosptJIlP9)9qb96?!+B3mZL+3H}m!G8$%mz&B8A=>8R8vKng-<+H*LFfz`;s zqKB7E)J<529mfXSfRgLTE{G(ct0b|Lp#ltfh5t{(*X&ws7lQcZq<9+% zV0%^)+WBuTl)g2*3|YK+;VkyIz}H#bv<1AR3d-OKDQI+kno0%h+OB$}tk8rJxUJxL zo%FY~^!e&jzJn@*XO%aP{mSR)!|s+P??04*PXG$hb>eI%b?DUG{C5d^C8}gJTy`ql z0N*_dS2|g_70b{+dVcOXHS#ZVU#Y>X($Z3bO#Y~?t}iJ!9TkV;S;n;j#PAh`vqK9* zrANqL@M~4t(-Sp|*r9H|&%NKic0NKE|+EZ_&?G+6~0xqNUE;`jb4#nh`w));r6?1&*?&;#yc!iLC5p>uSX1gakrJ)$jBr(uUL^+L zji<+CIJ@~xH&|;@wRVwM$`3pqNYJw5u8cpBPc8a<9qIe;xm<)$cVK~K7HGKu{vp`Z zcU1#$+LbfFm{a>nM__sF&&|LQ+0U#mK;5zi)amK^w$p`{|5tB{5l`wR)=}3Yu=uq}SsLeYjcXZDCAOU} zG;9U&@a|U*up)KLXeafYnZ+UEcw#hrTxFR(Pg2EWs&FXbls~Koqg*zMN_yJr^X+f5 zHHwjjbeLloi?i-{%V-c29L9zBY)qX{E~j-$_pbmxHn!5EF~81lK6bU@XrT3n@eo+; zQE1U|6!Oh8y3g%3rB!8TS7$grNT{0rE_f?Y}v zv89vJIX`+_7wYdUu^9td{;bsjp|%?)a=EN$rNC~nqn-W?5>x7Z8|Avq!>&MsMs#T<^IM!AVntqiohjYbW>+-(ZnbGn$7ZYmay3aYt#K zGRkyIod4M3!U?vcd=O}yrFOY}Xtw!mWU49Ht-MW{;GvCUwJ<h=-*U4$6c_Q4gJg`x|YifNvtFSQwR zAZZce`3ow6b7ei9xPTfb_(AHPj@rZx(xU*mkQDZ}8x|RJ?+Ed@aJUWMjek4e`*h7p z)WZVk_&5$3u(PSZB`rdwf@k0zU(Gee3#WsZ`7@3vPhw(>K`G2OMDUn5Bp}lH#J6%Oy#=IZe)&1JaYTn4 zye1Vc$_ORd%E5#^s;Gp3sLjDVL*=f{bzJJ}xa`VR$TRfR_-zW>x06QxQN^^S?C_aq^W#? 
z{ySms!S{yYG!ppy-&q=!@HytTA&PUGgA9Hn{F={)KhQBi`diTTW)ay6rq2j-%O=Y6 zUbz;RA{<|-dB(MPPz}>miRR2dW1UPN%+&OCSt4F|?1R;S5phWdi->crh{D9xzq)IZ z1;ZtY{AfD1oOl&SFaO?m028RRQF0rw31ZABika!W!u?ze8TPTORi89(c0@$-Tf1+PZgF8%MEdDdp0AvU{v70hN#9eY!7m>`oER_K`3NMKjwG8o3BoTBt_r?lVRe(%@M?`9OXu z>_dyaP@8-$0o@mAGyk6RN3A;~Z=*$5-|T?AWBNU|bbfrCB z(7KYbYWn6X>OT)nidOhY2{Mb?l zy!XBxgNlufHK|qtjOXn)Ti!&Ja8JmNPK|8rv)XMT)P{V0vl!(P$q(}a&dBqr zwC*stWiLHPOo=^r7|*TFUs50a+$UYj7b?4r3)7eKGDYhgsjP`8OQU%yOMOEZ+A2r8 z{8+y($n>E$&hTGL+CPb&Z%!Tm{NYI1+jRxykD3v@rN&LnSqrd7iyKEw%!7}O7MrvsDitKTb9XP> z^_!)A`D&9&zBUp1J|O!}!}NR_x9IKLOLXpE(=B67b=#)hf3668dTa;6dbx&KukEhORfT zyH&c2DGMhu4^IrmXoT`5yBy6t2wq>5l0nh z+FgF)G>~e$=UD;l^_H2bvO1vo!4mpCzoN!y56lYeclLUJGdyyBz5qJ7i@q9BZ(QTf zwC+UfWnXv!CfbmYl}1 z7#`KQy>#JKqknMj&=Cu{(ZPfm*x?xW`82z2A4Y|tN>B~9;%Q)vJiOCWJlyBDV%%95 zig&=teo~Vn-RWaYaaxLddxSmgeu=Opqe!~+GIiD5bmg2a(VXu+f~3}7wD<9qwl|vh zZGS|iynw7d&=vXdcV$dGIJ<~2iF6z}1h|;0WN)E~LRnZE&nCH)^x8G2N>6xF9+9^D zy_NpcO2o?z(lhIJPn>FdQrXWP*qY?UR~9WNE8-z3iBO2f-`7r#EW6q`!R?GP;m}7- zE4W&+4I=zP2)BwkCbaQAx}74N1g*&A3?F#v#Y|6Fp2vQaCRe+fDis%1>^gs@pi9|% zr=~@UJ(#zQ=EMguyJkKbZ?s;sk2gW08qBF{s_FF)3lwa=qYSSA4iz!Aym{6s#3x^3wYGT5MJ3=D z%$H`)A5rR-`b68Zm?DhK5o6Lx)&#=d)5CPHm6w@r|6^nK5n;qyO_k2V7Od~FK5Co! zLf&SDlrz?=dX{#_WUd+AJRa%(KeUvV`lu@VE%RJQ&zT>Q0bOqZqp`cIsP(jAYsWJH zE^m$+`bV`ZvO3|sdZzM(3Fn-8@(KdWi&`?kB|F(p3Ps{HwF}yH zD!PRiH;EM1Lxy?!ZGS-3fK9%lf#2fHASsGkb56C#mEM1p|7{%-lO05@hB&ws%OLW% z_||H_D&`2Q?#@tT=%#SV8JP5#Vd|`(v58bdp|o>bTBd}ZxqA(UtoDvfe&n##s&m{2 zt_3TTI#gg$M=utHeAG2-I=i;`-C?QB*kkx_VDlZ5d0GZsb_dC}o)3Z1Ocx@Bqcvrs zqmeY5%D1)B;K_DB`a|CK_}vS(*G0<5@E>!!4_g?>k|V+R-jv6^zoh@pb}{2jFe$xx zZ?z+ck5jWhCNewB1;z^JH3F-(a*zqhf1WZzq%Nzl8%fmf+qPc-W`+KyOPmL;nhbvh zR{gXDEo)yc#bDGkVk=>}dg$B`F+@XaRQ1ing!4Sf%j1?>x?btkCO);q}a=QN{0K2S=k#u z8jGzAWO^Z|%dBG1@1#r3eikvhq0i?13h={|EMBUs`>{}-ws1|GMCh1K@I0Xt4)001 zE>u=Bk@JhM+Me!z7?#(UreomO6u3XUAB%_lA2wS=v{dRRljUfH0f(R2I%6h z@YX^1z{LygkA6h|a_`Fjg4-Ua5?CvD0km+A`S)VL{nyHg@k%qXaeRxoP92Q_ciEFQ zS1v@EL^7)4-4SWH(YEsMv?0sHv(jw5G`S=3W26SS?{2ZnJ?FnJ%fBv$BX6gc_uf4g z^A`ia=x_?#b_DU^|FHMo0Zk^|+i>jbu8N8c6>yazQX*YIf@=ZkDhg7g(!10EAq3Y| zL=Q_oZO>_RWSz$BoDh|6-l%q7Io@|?) 
zWzBAN?q?rgK1Hjy?Ud}yeRRFOr08)?lj|3z>Q|TvnDfA;3QfS9WTu{}ot;G_o;F-8 z73I4>-4-(0!g-qdz6&>#X}?C$1F#w-SSIy_tK;o5BFu*+!UaM$>exI=3O#y=K+KEk zyx;7FO3h;xzcd-_BwWDuALV&u_-I}R`SlHWnEvNqYY$v>U%uZXTdh!DUP`$PO;at5 z8*ad(r(MZO(Knk0+pWSLYif*o%p0U-@MlXrnhLrSa0-drQ2pDW7wt4y`kDOVLcZ@j z>HQ(&aU#=*t(EOcOV8l9k;~hQUx}Somi9jD!SU|=lj=NyO1WXJxw=WbjTmRvt*A{p zZkXo3r^kz}G@WICLgaW!r-4B#_l%y`8Yxdz@#9Piju6;D+%%?~%NDA3M_7&Nt%tfM4eG_!u67Oi3;JxTPSb+6dmtDZl6_p6T?HE%c$ zH?L%SMaNM?iPM-$xV`wRr|gu&lhdg9Ei1yrM*S72X>+2A0sVHS!^yVA^9dsfb1s?d zm{Y{VT%}y~9EKi_fjSc2Mlse9F}ZJcMk z%t^)hP2wlt2t=i>X4NKVU{4>8#rZc&ks{ z5;HbXtC3{-9ZIiwYDG}myv)9+xx4qQvo&iC-0!8XV>hj+-7{3mJi|Wadu2dPQg2nT zy-Vg9`rbaZ<3m=DTEA7{q*F%}W74596Qnm+OkeyuEsk7~alAw6zD(CJ=sWewGq0hQ z>k|hyuXsspqomC4_zoYAh(uv125r)`g-StdEIRh9uV{%*o>WTj96#VcQJIB`Evw6} z629|j@$sTM8*$O9FYmp+@U1Hq$ZdpJVGqLy;^z&U&3!*=I@Km4u1|5{xRQ_=b@y1lqBs$wqhmzRY1v*NcqTZW*&9?r=7 zaLUl|9`ojKrUMH5P&2sH>rojVRl-u^c>&`!<~t&{h1%X3y>_-IYjZ!Z4&~VUc>c-} zOQ`)~tTwl~13}P6eYp@)fuz`ekg^gXwq(izKm^b()qUFqGCK!K0)5kGLy( zS;y6KG4}8*R62bEUxc}@56ix}<>3&pVTnyUv{}zp*o`Gpadzi;j!CchBWJk_iE6WR+c% z(U{}s@xh|S8Yjm1z3~;Q%$h5FLvR-N8r{9OqrLLy)poTHi%x6{3@<^a+&=Q=O^deDaq*5xZpdcQCir@Fd0q*x0JT7!{D z>8b?2DM5D|T7O)XzD*6glPQWEY}{IwSqJ|bRwt`*Qxw1MWYk>X-9?SyV@+KPh-tp6 zf#wyh{jft+bl==>>)&Jm!J*V^TwkiRn8&vL;~aJnTin@LD726nc_>;jiv6)rsQ(n_ z4e`~-kB{>41cGrf!Sk}tr4i>&m9j!HAp9xG?T<;MEC zaQJcy3mpE5T4R&AL}9wDykQ2V=D@Mrpx@7ei`%*l9bM~S4s$)8eQ47T=EZ&AMVaDR?#CH|n$ z*JJXE<@kLP_1*)@cACk@bat+86GvXD$#k~$*w4f5UDRJ77{_jv&mI|Fbh^@*gO~J7 z&$IHR(+|2f9t$3IqH?yp1yRlzWKw;=BH7UNUf&cryvGn&mwm0sUEJ(xYbt?HyF0~=AoIf*NvSHdcWh-J6gXM zW%He^4|~Q#h?A;(T5XKUEW#TrEgcc-AWDOgi1(x!6RLU z;wwyA?CA{6WGnA>@HaobowrmaO{qpv)zs17y;1#GAr$%_ccwEZOxT?dj z!Xva?J36i0q#JSVvNSZ&<+Z|n7ind8H(J(0)uiTkFBQH^UwlOxq-Rn0dq=EpDC*4AZw zHI#nT2fO*H&nNdugR3|7l$%m}Nl0wnB4sY8qqJAoQLE^y+E9Cy2o&3KbLj@#MjXka zQcSah&*|q>8Gjzxa9AN@=kuMRECkGtTRZ zRiY=Ya}nhQ4`Ux#UgxUlm#UtdI&5#MQ)KKz@KjWd_|*9;1bfu4`oq&*I)L{K7O&Zy zhb_#wmu1S<{<9x*ueDCnTu734So7*JK}pR#cPYfGov}-C>!sKm_CR7Iex?@7fR>yJ zw~e)GU48T+b;>t(yFg-}PnZ*Upxp5`1C#5GnQ7W-?YVa?A8*cfE{(G+u!bt=*e%wW zAWpS(zxzwWVdoOe3jbVza;5jEs%_EMDz>l9zawx{4O&J{wY@RF_QjOJ;i4K|Y8~q6 zS6-)sKBHZgjh#{?-q|QcS7M)Rw@(`nR{c<`D z7;wp6)mM13A*`bLL2$8nv2tEZNYzjF-^};$LZ$(O^E>-gMVxSV(Gwlo$G-6mMT1~3 zZ8wTad#Bd=UmI!raml#ScuIn5jK=qthT zol3g?Won=OMZke!=aOe>oieE}5yk$Bf$5iRaJqXdq zX&vzeuFTP|L6Hgmlks;-4v8=`i@sqJDEL>g2n*j;O$Bbh-7dKWUw%1sy9q7s965L3 z^T(K_BkNv+-0L>)ARR&@IDwb#!&NNV30YM4BzP`t_n;qA=tOVSv7{|%8R=T)@r`Fg zr*d?LosLDBp>a=61x*AarA+er7eh0676ve~r_q)M&NlsQOmC$wNf4$8ig7vaSx2zK zrf$6vXX^Cv%M>^lg0l~)hN*G$Rt{%IvarUkRj!GWXRw|oFO^Mf9@4B;WNgb2em1N9 z5UJKOXosriEz*@}w?-^+7odpGKVMuX4j6xJ_;4E!%YJz~q^w0FWx4wwl!!CwT3NRPa!_0Fg?dU#jYgnbh?CXmu6-(>b;`j}-?%Ep+ zLb7NKTQ3itys<*ueEix<=tu%SW{YZDdVsWuB+^cUztC7t6At(k+7%wR8tmrya~pS`zz;ep+o$Ti+Mk4g-<<1tJ?F zuEA{=Bext9INQ6gs)cK^3fuOr6W*^g0g@w`i_CWe{*q-OTT&&W3Vyj$a@6YH5zAQe zm;MUssY~Pm{~e?3*4Q_t(WKXZ4iM6&lXWRh=WQf&iafiFWTlJaEdvc+IWPBiZoP5z zyq0_XVV~H|b!Ej`r!y}M`ExPZ&W1F_RUC;XQCPFpDsYux$whH7K-kdq<@t=((9In& zUs#t1+CLUjQUXM0OSwL!vsvCByL;}!)36@G{;*yn(Y?%u5A2!uo^&2GsOLk%;mR2L zy=LUJ>)R1}@3<4(-B}W=;_g1ZW$Psk0caS3xHkN*0c@81A%F6YKG(jj~2TG4=Et&9_8P&%88A4f3spCJAVl^}r4dTmC4$FAkw$Rejme zF4zq2EFG@0;l|=3JWH^3#=l1`H0u4d4HY{SqwDN)QBBw7%}g-gGdup|Co%`8uCI2^ zGw{*Rh(^M{1a0h1I9wh>ue?kWnqXE{$6*UQV!$yWpQf$z6G^!Ke+5(H5T zM39B5m`5`W?aF#qv2;o&aECBL=(ba-?o03PyFgUOGrl#8H8~pbN4*YJRHQ3j67kTU z))2a7Ay=wlR4Y36*ml~Lqfjnw?%Cti;qth~h>ANi7Yth*NxyE(bfHLCTC_BY2<4g{ zrKdmg{8DNHqDCQQ4$A3efzIK^cEMVNU$xdbh@^w#P!6=SrdQ!aRT|YhU6um zR?tSUG!4sg?VCvZgAHSeZnHtBYNmSt`se6F04)o>EU=fD;EVU29Le6{_q zv@ZR|Fj`>PPP;H;s4m 
z!}0txcozcwOR^yb-3YU)%uExrgA$HjEDfQ%3GW(M9(f$zX!&nZO75I}_?IDO90^9Y zCrZ;*aydu78#VPEE}x z#LpRa9QC!lCOv)@r29(|j5xaZ9j%tF?D633p|Nr6Ypgx8D=W}O*W-3a55BJLg!F4k zwmiGd76}qJr6h_i{q8K4Hv)1ql_re+!^CFyE#Elz3LCQ_EYzr*Kl@r_{Fy&$#r(!c zVASgWHfrXM)&4u#`YqFz(U&hu)?dC_2R4aQFr1gKjLyk4v8UWP7O!!4qywrvq9qw(*+6@*nMj`T0Uj}be3;>d+}8Jtr=}q#+1WaM!Bp+^aZi^8ZYP+-;&JG1 zyYr4_=HAkH{mDGFhZIWZX6KYQ{MmVrCf}0w&Ith9(1{NGS;PLHC!2~Qq4<#WKFJWh z)I0u2)TQDb@$=|d01lXJIHJ$^l5ZzVJZwS>dL0p3jVDwD5f#7TRjDQN+Vh)CjVzn$ zrk-uD39crmv6{46oZfsq4B*!RX#6CpYgWfcYkA*d$xCSv3JCbOg$vI85h0DXhv<$M zZ^Z@<4W(=QkL`I-ZB0MApke~$p{7QrgJqsXXOo0p+i+?F@8k1?Z4oTu z*{SYDZsVc;0{4|gbOVac88pxCT%Vk98O5I$Si8hSCtMlz4^RE1Pc-t;rl4oL#Lo^r zzz8yCP=8_GJ-`eGDM z>ycT=!c&H&vYYZV`?o;zQ)qG1zC^Lr-=(wiM!p0XVfFidp~bjv6L2#y$64)K2E%W7tjlrUS{A>W z>ebcTo+}f!RS$H&G^v7OI}g124?(?+zbSbKTB8*!*LOeoxhUduk4vm4<;{3#p1!!m znNb0{E;@KEddnd_HoQZ-hbeii&k0 zS=cFk@x%i`R;toPX$6fobYI!#Ta}(W)iKH@6e#7a5AtP4?=gm~ZP~weYIf*;a?|P*2f&Vhv*C!GYm&JUm*BKu(EU~UnEx)crX!mo# z4u)`duwFOZjZOyz)@GWSF`mAAR^(2MMMjxEfd%+>{MMfoRMcPE@;cUQG%nBk%H{0r zrxG16)j(qBt8K+E{n2~NY>OUTzhYZ3IAjY${Z^CjV5}?VEM&QrhL(wQUg?cDowWD; zNy`?O@tpIlrB{pbNAA3+NzRnhhs2SLqN^iJaGnhwzoj+dD)MTHf2EwS#C25jR;l&& z=~l+-dUf**fux>Afx^E=Hoa540^$m;QT9basON6tBFuHQtl?g0*DGkyvywPaFeYg{ z5Rv|Lm2)R2c%{SVsY9DN!d(jl@?NF>ms+c2Y19`cy3joL`vN)DtjhR2m?&>Iau7+1 zbwI^G)XYCrR;-~bs8RIsi1u%v-@2qRetH;mj!r~zmRBedAggD%X9`N+bgepDdKx3z zz!ax_H4;b8CJ_uCPWns5WPA!ssA+kTbFJE>A02$9OhP+5ZpZ$ z)M28$=*%W~Qkg*7zQYKl?sC!#AB#o+%KWCES0#-zUj6%s@b>}?oD-9d`1zfG+#Pb? z5#jx5M{wFVtWRWSHe_AY|M_0cGq^#wx_rfK?fhf%PrKNvx>q;TkKgmRdn`uTtt@rn z1F2}&?8lkMdEw$!JDcW0r}g{%7vFG?ciNpj_=w|4bn5>@J)229{i}7y>lG>EPf;~W zKtv%IJ6mu66{ucK^J279XT)LH2AM;(iX)#Z&u}*r^z?s5#Oa8h^m9>;S4GSQdz4!o zazC8c-*~^8_Lg%h&{}iym%(S>B;=Mc7vXVItlz&UQg38%Z&wkzn3Pbgnfn=$TaHMwfC=uz%`S3oSp)b}+KYDcCI%#1~638za`&%2H)+1z&0 zYeAaGXtbWFuE6_LD19jS)$jAk)x`(gkcm*WyI=d41I*_;zol>=>bHgq0~^AFM+1U6gQiEjG6_6Zvv6Pz>>J*MINbz8>N zu3PPq_X&FR8xTp`zPr38?cjJlYhZ#;eR>aC7z-Ae0IcxPO-+adPTsIAos>b zLA3%50*{3HJwz|MH7(tp-mTzlxfOUM^UKmbMV2HX(MDXfgb8bqZOn|@ywElerAyj% z2N?Oj3TD*%%7-xR_T8x4YhIWaB7XdCD#PBi|0qT16vw8c%Dei*ZJ5w0AO7eEk0!Ta zsl$EO6Q?(3jU-faF*R>@1n(buhS8_Fgmvq#+-Y$bVjoAL3p=7OBvY|_QMB?4!J$DS zqn;6uvCwuI*t!^*?Y{?WMNXougV3 z$u&>%UYaQ6;}au}F+DuQT?{fvEFXKuuR3X;Zxr4@&Ij6_9vIPkkFisqsTf4XM)h!h z?~~45EOwKeI}PBrQoq@%e1uv`%?0H@BY~bB4b2Y~{`o}GtzRFUh_XGJPfS1F#CU~i z8gvP)&&T)u($fU6m^UU?JK>$Jj`&pLsn0{N@w?UgdXw{xJ57T`p~t^BtBIe-1hkD#fkb;E!N*-^n13vneex2AvRE{gyvuv^!-nk z&5I@XQ|b1s_piHh;)}bjC&a9DCf;)~!^M(@q-&>&&l?x^#MjK4P?y?-OI40G)*9`5 zRxQ_gU+;3thY{B0jvd$KV0r3$EKP54vW8@)Ke} zaS?Txoq2G1Q(8rweUBo?#x6@_L7V(4&sgb8&&0TI0?fezS=~g8v=s6NX8$%$urf){ z!;5X_?F!>@9W^{{G%{d8fheo9861fVt{<-I?kBR=*y1#trR&Z4WFhF&&HYlsY_Z=XO+|gfO zwuejXGvbLkcJYVm8y%B)qbNdVZ-L`Ld&O!==n5%yYe9UhdUVqp#uk|7fKRpJm|zq> z?mNKy-PES3XsTDFEpXX>;oK|-$s=@@wF53u5ZOY!fS`&?%8kW|*#>%1M}&5VOA1a= zcN01Vlm<|`XBP_<=X|;Xog#i0R}o|(SJGHtHMVpRQPQEPU>J<*hTk6B8=&duoU2^# zK0xW+JpNTI48}1k9C0hw2UW)H)TX|!*oIj(P{&^x!$TkRJH<^u{@O=aB!iVqo1cet zrKesfL%l9C0rkL7Sa*CbtOFClcLx9_waHmE$(|Bi%rSVD@H3BNifoaKokTsY4Uenh zR30K1`#i2D7K@EPH{RJ8q2Xi0?K8wLA6@0eJVpTS(HMLt#Ex`_+Uvn1WkyCJ{HOYDr^3#wZFenqX^m}>s_wnJ zHz3eBA(QD9JV5Qe&$&%KHDr|Jp!8&QiLni&mrSD=HFoj;rxCW!aaXG@YIPfT3@c!q$Z_l^|hXPncz zz04LZn*ZarX{fkiC-Hv-8FN)W!a1q8dY4?-&f^&(PRvFEzXi&phlkZAbXt8=lFvAd zt{Oj=d(q$4pHv6rI6UAyI9=_G_hUX%sx+-uZYqJ)?F|ScWAZk1J*Yk;yaG)(8R#|N z>ldqh5xR~dT+$4*QGHA!aIlMc6n7%vk9c5=9}-Jki}N4h)rJCZoSFTIdO?Y7D$CGb zbga75LS7&R@eaF?AX5j!1;l-}eWq+K8)Y1AQnviLkz!Uzb!#Z0TjHWhX-wdU3)EL%fDS$%5h3 zFz60dYQ>)AdQ_Mo%j@YBYF#ruM8SIxoh}ah_Dr9sU$T07rFRZ5Bc*5`#xla^a 
z%c3Q`-}_-}I^0S6WK+#~L(Z66SaCW*uNqZZx*yaXwISDAoU>*jvlXB}0|?pr;$9X*2DKa>XO5pO^~V;gHL7xT_`CEuhpK)r2m zaHr1hn?plM5xHN^5m@D%-aZ1_yYp!N!7>ZeKC{ctqwOs(CMt+U+u4I&?VR>Hm5y`E ze%+OQGLpxnict4sQO-vxRAuA9eHus4XS0Y#|~y{ zyHvF66;T=4i@K%o&)=aeNS*KzN0tu)*n6_(?z7>*cqA>f7rXLB9uP{<0q+K`oI>Ex+u;(n__x;D0S zEBy5g{OeX(4IfbqQ)iohgPjc{3?a=uh+=SED=H?ro1V@7wtlBz-%UkZu|5)85D(Q0 zXz%U5E>l8J^BAMQrrm0K&n<7+#csLxL`zCxaAw%P+$5M=iI1w4SG8kDQ{Yv7%s#ls zj@Uu;gXAwm)D{?r#hwZoG9LH+)39f%a*JKSpD1Lh-ybW}rM=Aj<_ za)9?dtS0ik781|hY!sRl!7$HozE7<-oeI@pSvCr5O(H)r-qtLgoSzvCGl_En$&2NK z+5COh?`qAp??isFOa2P$ZU6-O(07f6Bl7_^Jgj%MXF?RaHjP{YPv@{G7#j_=)-wrV z*hR?iARa8%)3YV(z%mL&MY?=FP9)mY5^GKqD{5A6FEDQJJ?%3SQ_ZqSQ-H$d?QHD` z+yVGisvR6I9yAy+dQ^+Q6lQuUG1gWnF=AlqElj%^3r_GxHhVyB5JCLm0M+^jTF5S+oWwHrgMs*>Mh z`iOg7Q$1!e?CXu8u?57M>YC7Y4{2B^>~W7exNd<&u&J?uakE+3lgEv>vlUtPe`4HL z_J{sKpQ>0rNt9RrOe(dJ<0FF?hW&}>JQHi|SXx3GQDLB&$K( zzT&lH-;1q4%|HYMau|4mGkd+N%ys~hf0Dt!8cp6#v(*k12dFi8DP%`GdfoGVaJ|&K z{#AyNpU(cP)oUL%{br2KG~R0{4w_dfeN41)4*U?tbmk>u~)A0OL7$kS91vdHy&0xK&Q6W3AThn242leUb>X&`b$t^C@R zWdw%&R-elqh7U^0$1R3jSk^el*sANSGu%Bq98mqf>O!Z#J*<$oN*zeEnR?|p=D`;g zJYppHD1$P!ip zETVsYzkHKqI^}-z1+`<|T3ecfIpeKktqXTqiD+UJQJjU%vHWb_DKsl{4$7Q7aJJda;|kWwsKF0}*@tW4 z+~CU9IXQ7vftj3)2V@L*@_hr}u8F``hRD~KhtByqYfWd?erT_1?Pw-ePY#r{#@W3; z2pZpjIWEA%)W^5W+tuf1j*T59o?n>EOYeGvp^?R>bPeaciR$P9wfIG;o;LzzKQpuB zV#aD-)DP(mgoM_(dm}i|T>W6|FGsXhhV2LQn4S%PrchN81S?)Z`>J3Rziu;_B-#}A zL)%SV2TBkF4f&;udFj_Iva~y2F?v?RidLuW=Pdk0Ge-7Ef#%`Cpwwt3k5@*=dStbE z=4;ARBmFJdWww}z#XmA*XoaR8~2tq-!biM^H$IhIZP*K$0QdONA z$gUDy*IBXDHQ-fn7S+xB)_FM|DCj@wru$`xTRv8$>(02U04hqwB3RfR&Gn<#w1&O3 z)LCxZ$lUT=5-Ap5AT1vNIR`>ZE?asX^Cu_d~H5BTZ$`z%mcD}9g;t!-G&5t!< z&aT8dZUn8v4OaD3Ha1xE(upV@6$|d71cd=Gm?`B*n5UDsNFTZ{1Rwf2Ph@8H>kHdg z0fV*S9aq2k|4UEM#r$F)ex*ol@w^PP7ZhbFM1N{;k?d4+8Z0H%g!S=Q^}^Z1U4r{b z!ym3v1obu}uGKjf!8YaaSM+juhR&}C8yJ1FymB&JoxGczkjF~ir5ME`qXYwmKMwMK zmX?ZKF&+itwUb1DQ}LCqL8NebnOL%{CAxafC z8v>B&S{AS7dmxarG0`!LtS`wHu+v2C4N%{3QD*j z7v9a3icigToR?hpoJt~RkKhq$GduohF})yf%quw51NvZI$}6vQzv{{Xb+th#>vYuI zBbYpYP{k#A>y5xffnlFYm|cN-A~GxeFQ?eMh;m7mjH)*t2m`}Pk*6$M zhn05D1x656vJcow+;w0khc%CP(a6`IF+CIo;$+YDutD)BRHfHJh4LP8Y+2J4O>3`C zm}STU7^M5;1ebC2h=kz(uvk*h@$fINaOLcwMlyr#R5s|rJ0Os-a|JXfwqjxWjhve5ju^+Lnih(T5cwwz zhQqs@68X)CMi8s{vqAhK%AMf$q<%-bbxPtMSUuI9oTBJV9BOC?>}_6VDhvEXN(+1t z4a{e0(TnoC2AV>VUFPOg-uIsx8ulEh45SXA_NwouVfBMCgZ_NMEC$Qc0vFRjUclsL zPHE%f;5wLJLJ5N+V_;HnJ)B{KrC4uxp)vQH6Ur0p>_A0~cY;kOxKa53=*Hy#_WO$W z(1|?EM{fB~8_#SNjMwvv{BgcjC0~90x=ZT?mk1w^;~8#T|9qps&s_p2+P@pFKzIEV zzFV+&KJ#QAesm{967Yv1;YaDj^G6^C03C)vE_9$H&p@66HYNn}FoVDQ4&?Y9fpYeM zy3PN5;GZD^=JvnciY{gISKyP2w4*yl|81|&O-jAjf26Gcw>4jq=3+Yj`S0Iz`=8PH z#~A*{(f?m!3|%XQ{7f&t{Lm%&|C*{@SHdhz);|RoHvhfoyh@y{_TByb8`kq}2|Hi< zl$H<=$piq@-|N-u%+-RFjmHH*(O-+xW%*IWsKgo9zt%+Hq#n4*5M`y;832Fs#%UL|aN@K0y|8QFi#;D5se zO9nVCy9dClrjYQC<7=FPnJ#~n&Cbmwu`S?ULxE#sRa3u;Z7okL1hpoqjrpHgbM?1o z6{>5oxnp*9aLD)GYpxu?RwI}a#_vSk)n>N?g&vr+R>Qn@G z4y2WOi6tvz6jrkCt%Dr@eXV0d2lGe$I#n%gLRUjjh`C|26a)V!CLpnhpb>ViOT+g5gOoczI35MyZ-r8XTKy*iQ+w3Hzissw1L8QP?=X6WwV_R%;b$hnoRr6z*`v{z z88*x51nr%rJ=hpLU$e{@yK*sIn{u+};Vr|g#fGV+>dFHh!AM&PNcgR_<6vgR&(uzL z@!4OV2xtfdHK7I?BTmuJ8)-S^ol9H?`ECvS_VQ&tpY)<~en(8AC4jLKS3mIrI-7^E zcIUT3o}XX4JU#73FZcRiRWqxSV=#yVDnh)6hoBgA`F?zu3$N-rd>?q|yEP3w6x{BA zb+942VUR~cDYuQ*O|od@-x==}J%z@rYK0-sZ>;^$o^Eudmz8wxWx1PO=w%f3daQrn z%A3%yp{gCD)Y;ZDpH5g~<01?a9{1fST<2%fvq6NxETW%BY9O(f^QstIFG9oQ`7Rju z`@{jb3Q0RpYzyT0v$f{@7t$>}YHtBT2Oag-`mKWxvCH~p5Llm+&P|*0el&yt?N&G5 zKPp$b;%_t-heREUmDiTwCY5>5e}mnU9p&vB@)75$d+sLwLIM)VfQ(gq&8cT0XZ|6jX{+{JB=!RE(` zOW(L8-oEodVFu4o3qu;ru?82mdLU0kKU zqfl*bT=>7XKW?oM!G@bZx+3fat>X 
z-qKcrdp0pTpoT_UbGQFMs~;zbMlQtDFg4G_`lyO0LBtx)bQu?)ghJrYEAJaKTm=v9 z{O`d!@ns9US<5Sr|t)p_}MaN$#syY7uGu5-C$j! z#C+^sBmq;D^70r^J?GMGnE%&%Y8gPggu z_9M$v94O;jlvD{ktlVrNwG^Y#U!`q5SP3-eZ?87TQkw%cC4n#d`Q;i%Nu~{=E74sg zW)0LmzA>U=8|2Ko?>YiYgvrwFtI0 zo}g$rg4I9T+AUoB$5?23eZau=m9Bvp{F4GK(E^jX8FKvLcftk(wVMR{QA@q5La)n# z`zWrbdHAEdQk+UYY2aRY7CD&JsCeQk&R;F7s#bOYYl?O1T*+{riI*BMK32< z<6U>-_%S)1x`{6@Z9!yk@yQn`chtF{vrkVP(wf2 z=}q&k88z(L&5-AN*7|5ceBnS=vAf^QL_}fkuX65%h=Qm7JvKBK6rrP=i|x|s6;%Kp zY|R`$;|h978n$gNrz~iRc@OYwKuzu13P5K0TI)wmZ@|>w@)`8w50lxZ7z_s9k-Z*r zJn{$QJ-SgW%m&492LQlRleHF(CV~U;*TR1eUH=?sA&`I0-jILH;2$#p9`PSD_{R+X zF@tZA`#+1qKi&Y42mUdGf6U+?GZ2Ii|G&65AkP8vDqP)bcpqTLuMOdm!9@@SIyR}$ zKl?_dwE@Cw_8mG5ZhU|%yj&=rnMI!72zfiR_UN2c%lG9dB`YL>*uhN*B<8zM3~(`1 zQDz>KwL81@~h+fQ) z>+XoR8GTf&j$fb8C<4%N!`zSwz+1x2i_~*g187?NJ}U2p7ZhNbDsoL}loRCji3U zD*i|+-=gcGF+@KAE-9SZ4*B{WT+vKHFPB`Gv#X~iVMjM?I0oSM1K))rf;rH;!f$JA z2hlS-)dA`AyWY%!-q6%hiPaTF3wS!oMnI5{`3{U~2E6e7QT0Osi033vWe!Oe1Z*Iw z=hup!1en4-&Y+#~pj)M9veJ0D&o7Yh;kB5eG6Jw$IWI-l`H7bwBqp~*0dWx+^!L`Z z2ym71taWokzPTd1@D&dZ$$BHQAFP9XhOAW|P;jW1mArju6dLdtUB1&Cs6`;mNVseB z%T~w@r5`+Oql+9p0w@OIO+QivtY3DE2LGoYKRRF+J*PoDqY8l>`%X}AfFy=5ul@7? ztGtskh#-wIM&HC1?F1@X3-URE2neShrd|X{$c<~=jo?Q|f2$SS#w$OVC;~>u_d!4K zIY-vAo-^Wq`9|`!SV4qx2qCu)^0w#)>gt%%-LM(*dC%IPmNx)CI1QS-xfwOE9`gCr zS|eGF1m@{GaEtCN`1SaA4<~_NDWG4PJn-%Be+Kncjg=u|QXpV> zd)rzAOdoCs$8_v*;ukw{b5H$uXXtqqw8Y-nON z-9iDi5%Sext)>Kib8kG;2?poKM1aL&{aS4RUjmX!#Mk~=z)|`x4(EtAJ_sy*tkVH1qw#+=po82JS)!YG9&AZWlU`D;D(mOWpEN=!Ez>$m;0Y=_5#$ z*9?)hf8-fD5E>@RAL?qP`M{Z9IMvsOR=Kei-ovEau{(H#HWblydRF;|1Og(FKb!8E z<`Fs}vT<(Vw3tV2O?0(28{@;iH6zLF2dO9=m>{J79rFOxTaZdY`xE&7FB2#WZJ20` z`b=9tmlEW_n6^G$>_2yu=@I%stM1%fZ&XCZCdj3OYt;uz76h;QCUANKVS^3%nWoMp z$~aZ(V|cZP1LX%sD|6$2f|R5Yju}Y$?s6phSf($Q?g=}+y*I7 zaBW(Lrz=t_s=l$J&t*%!jR)sM(T7QKZ56aGdyJsLqR&mTbd^tz%VA&dhdh`2!7Osk z_6C%O$TVZPLvuhHYjDJ@tYmy!W@|;2z4n=`qN??sZV)5nN_&-hr_m-nW1wW%i69_6 z?Eb-g&BPHhy|XYQHP2o9=?*Q!4Og^?%-o8ij=_NL)a#!Do%^*g$&t6z%o+#O(r?Cu zY=J<~zpPa;n7AtOMKw`G<-WI1WnilO+Z-sXAG^G<$qcV0Ey(R*;isbYV`D=LJpn ze*(VK*bOt&ZLmE9sSE;fuRc`cXB-Xfl1(+Ul%O3G$d9rwkkBVuvYfd9Y)bHno3&X!%BJCLs9Xv;LmPtE+< z7u=~A%g#`nkb{upDL>eB_LFp2HmIsD8_5u55#~llG#C>S;6ZvYbCv>7f1UdK3wr9SxZ9w+0OLdu%BEs5-hmEy7Ns1r`$kuvA z$}7JG+8>-cwK!NN+qolPxozGmObukx$tu*^@_fo*GC(|=@5BQ{9q;?ea{HjfOwnhw zG>WBl9R7hWGN^C0j!{#K%l-Ar^C&P z7Iy0)OMqte-;_>eKd@D(d%e(}&ab+Nq^i}%&ll_Gc0L(u&;^ClYP_ANJ#EVF-cJ4n z4$p(?gCP6&yInz8c41(cx=Z+TS^iV}h<|!l(6fr7C!p@LINGijwLDAha{B}_*!pvi zX!C-Ip#%$Wyk9!x2SMs-u%jx>?X&ZB*0bRpZ&B-?zm}-G0KY;^-lf9A62+2TgRJ-K zP$E)4qgN4-aGlTSSOLxAe^{pfNwe6s{m#r7i+FkvZAJbQ8-#wJ8EX!tn@*>5F@e0n z{p1f=8rIb%(R}!33|pM^6S&j0nlHcjTkZn-!rx5P*V-Ed+*5P3<~rW7gX@Ro`R(Aq znkg<0rHnOoOp%0iVoohg%*+(v^w}2z%i0Lb@D<(C5`6or8_$gb#iYU1I^W5)q@EIg z1lIMhdTR!smq#q+cU099ycEv(XK>2weFG1vTbMgb1W4buT!4ot6n?TRSb4oXr8j4x zbv8@%>|EyZRFDuG_^P-3U7Y7#D`>#l{l?GjN?j*l*6<+*;;K`g@UAc7oxsgQC&m7D^|9gU3$k`6TH;FArIug*zfs{bP8(k%pxX+YE% z{bJmC#nRuI!R;EPXm?JyBUxUCnMpnf)$lAXLh($N=$^R|nV@x$)U>t!!Eqk`lG~B0 z`E)YPm2KuX9#v7)m6|tS%EfDNCqSOz@&I`u$sDoa2dkCR3N&|-6j71j!r~uL@9vj1 zBt{-OM8t}O4-Y~~I_le#o$P41mK~6{G}E;Q#-l4x7WV$f5JIAnpEkW7IQ(uqW=M8= zhYfw=sDl5z&oX;-vbC|%$tfwdTkwYO0^{S8!~q*LVI^!tuDp(c={o1UJXZkfUEA?q z)wZ;UtPDD}P{Cz#oDrY|^|cE|YC3;*3Z(_=(l@u&Ei4o+&AqwC88!Bva|&2D$?|c} z#PLQ?d*yN6T44SDmM}T*niwL5XUGPLyW#8C7>9u}nm<~mF(aLaRUqGZT9m!2QXBVT z(a`8Emb@_AriL7>&;Qz%@@oFp+~G8=^$Y?jKgqn80A|%%jvN`xDsS@J_|gUa-x#WC z_6ms``seQ7?}z4&HI)^1x{L){%6Sc$yI@aIGLvsAJI^4rY>i;Q14_eB-#IfX+Eq*b z5LVwZIkoYXI}Btc7e@3$83`HtqcD_43p}wV+WG}*f77EH#@Y6mN@H+%`JSAh%!cm`VfS5ESfSlXah@-&1lb|hQnx%Cb86jd 
zjs335irXP)e*4ZZg4;^VT<^=-srHDvSt($7hRTpkzrL5S+)A539O~hwTeqmA+AOpU zoL}*<(uwf??u*KS_C8OLVks?y<9hI_TzM$xV31buofg1xHOuvWO1_ju(+X6`xF=m% z&oCBFCbIBJHKQ{#$bp0jBLZc|%cj=AD)(wNabhU9?a?2KAa(nA+}c?Iz>)qw>sz<= z0prhO^XtQ60maG9nWl1WOT`D*;Y&picdk%Nt`{3h$mcuJZ8c~F`5&-f!D%G3a&j)@ z!~md$raXIU^g!!#fotr%M(if5ax%%IX5#;0@7?2?y6^w-SgUmN!dAOdD-zqz($-2; zR0INP*R`xuNw=;t6^K>=t7s4q2!tG4XK80^+ER%KBweYn4mGWa$Sp^tN-=~KBT{aO z3P}u?1d>2!97n0plLOeZG(H<-vcDbI$wndcSV(*X#Lyz0=+d5T(ZsMkg9n zyA3C!(B1&r!Lw6L?BqM#)hXy|GqFD9e%?}ec>@{XYPmjJ#75t^xNF2wwR~Dy6CR{Fd$wJ=t+(m!eLVe++eF|*!S01smblsVXiXpxBB7M}Z=24N`%N;IjLuGUE|IF45RuQ8(4i zFyjK<+^@nJCsxAoW8X;%6Sf9ImUGbxedWly`$f!#26k0_6=1sk2cYSs04`jD(AX$l z*$YTg>Bu7xw4r)sx?l%GC}gqrdDYhv_CO<4hIyVUzC1CpzBhB}z8SRPN$6N6hm3it zsF7jvoX}(_EqU_qzd9ZwE^{ZtZSDQeSI7v*!;=Pj6fTRk`5d`%xq5~rDH~r#JWRX! zyGdG7gH&bL?Ut>+u~Af})!gOAV(QwES~;QY#`=K}qpPF)0*OBNopS3|8n*-Xnu(FW zb394LWF~>yAO`M=6$*k;{`l)kVf0ro>BpA5e0>!vfBxdG^L!-uM7VxeMQ66tAm<}kTGBUYl`I2; z5Zq$Q2=hz~e@?09rRtVm!x~UW-D0e0$(H9|dhhRX0@0fXTN9l`vKe#CY%Xr#(`g5{ zuA-1q!-`s%?CT^>$I755e=u@V>>I7AE$I3F4wz6W7!dWMugJOEiV(Bf5?P zCH4Z1mcCu!|G}o8f3jqEzDE1lw$+H9SQY%^_klE;@#@5R5^5aOkTY)L`$o{Z#S@3| zqz#7lhlw-)*a_75`$Q67f%dlji84Z)k)hEESs~3+jo>P-y*a8bers^B?!Y17BX7#b zK1PNqJtZ^P()Xrp7VYBD#4#vtm>4usYJh8Mj>mRgH}sqQ{XX~;&Hc`3Vj z2Mh~@vLRD5Cp%XUZ~b-hTgl<)YL(@I9G&L(nSHZh$4_0Yt~#_?!o=C)=p?({AA;F8 zuS}#dDWScFtL7Kk0a${grHct-vv%?;T-7mcjwbwsFB^T=`odJ=N8nnAX4Q%CS?@6J zC&ASy|7x#%?1qN-X(P<%{mJv$S+S)~-YYNBu*@^APdS^~nA4ls7W-$?0j{g>7%3qi zO>RaNbddKg;V|wO1B+ADR*zHORAy)~hp}6-#l}@m9;8A~AD^fwKnx>ZO}QqrFB^F32chR+iP()#;ra z%VU#G3z~so=AFM1vUKZV`KY>7OC-CtBkrC*h`$@@^5w)UrKHYe#ZOdVTdUC86Ok{v z`(7VZq>BRCGVK7gMV0+t!WyOhCKfX3a|Kw0@Dgm-?+=`nLgo(SbPshUa+b4(Ue`B7}@@@|fw zbW>RrkJ(l0lpkRhWEaWIAyX+57^9=-OA(LRBlZtV!WT8mv=aH#Khv53NyY2X{=SQgbbi}_9MEp_=afp)>yphsSz^mr1+)g%hC6qsc8j6z0X z=n2kZOYdfq+!SA8YYpnURM}Y8cE(#Wn|9%uiPYeI06qcDRzH#Brcwj{!#>9Uw3hmR zV|WS;*zd56O1EJhDYjk{SDo}6Ik7`$;%iF1j`toq#Na(bySSX9{>H$+Qm${f;SK zA>C|8zbGP0*zO)3Bv#=MU5h2VqjJ6$lo4h(TP33nhQ8!9B>2y)j*+7y{T_1rT|ypd@MxL^81(%CHfI6wMC)@Ai;#ekW4~ z_Ip*`ig^7x<vAecSF(Z4%| z#5YhKFu6@sY)Cn}sjR0`B5Q_?tCyFvVBpBO=l3aPgy8v^4>;z1HNX_dE|@uaCo%5x zNm&<}AH=rt`~RuVON6!28u~lh znM8fNP|6v~Vv&q6af5dP7CH9xrZSuav9*rCQwH5H_#dBL3&X|6!s`2 z3%;AUa>Nx}J3t(%s|zxDAGNeq^ja+P-J{QPAfCjg$Ow!#{br*hTxXJp% zoNbdC)dk|a+w#tJNlA`ImlL51t#zaS4WU>mT+5wh zSx4@bC|&>t4oj5EwD5Fx*4i*9GmKv)7)~Q@H#OM|0p+~qAmYem|GQ)B=}kgqc$%Zj z8ee1Sh;uk{uswTJn{^T<1d^zmK8RfyPA-?2Gxj)-wDu+8)+XyVhob|J(?D=w(oPlT zr7z>(v|cHNapLLi?(ES6Ft0siyHsTbJX)f`vgo_13cW_6&4al(9a(Nl|3XF6k z#df7Z?5Ok=x!g7TT5H*8Zjs^1Km@Z|N=t(bVaUoUhz6FRkXdweSq*S6dw%#vuH$TG zaQ5hB>n?vTEwB!k^N)cN5>hxh&&1%Au9q+6ddt}KjC1D$@fWT&=?%9mm^b7g{U-`g zJg6*e@isyK2tioq*XhwCfG}2227cbeut*BH2%RCeH?XmsiYfGl;4!s@0?oUQI<6DX1HcPEPU zf4WutFPtzPqnmSdez*zkakeg#Fy~EDw&l1}UfR>f_U!}DXi~so2;!vWzFW`Sg3T{g zO_zfPsf?(qh;Ag=nHoKX<*h5+BKW4J(ss=WTjK7+6X#AU;{z#)JQp8M%C|Y{)ee;F zH^gs&^_^thKs16>qj5IeW#5J;A%fJisK9H=qzPJZf^QOLc(~c_wk9mLWZ?WDqNc- zE@Z}@7He-cH8qZ=w{Yfxt(F8%jLV${MIfp@CF>ipZvR~kFC*E!9WSaQ)#h4Cqd-Id zGK)$KgBm zy*emvJJK&1CpXpSCEb{Vj=}V!X{8BH4m(m7DG9IJk}Yos140d#L+?97&J_|ZtxZk{ zM(75~hy^v=2GL4o>Mrbrf1^-eBmY{(nA$sTQukk)6C)c;3}BRtu&2Qn*)h=<>5`$o z;WhSPm6tG?KXNMqTxW_wfH$H?@KN1EmcjKhz>4psXef>hEiCk;PqTPyy_| zZ;B@})_)o;CWB!)ZM#RG5;Ax-#1+9*I{_&yGXDIaNZ zgi5AzXuYaI&&MG8y=3E*P@Uxyn`i(9o>NR5zhWkM$nBax*R_z_8%j*Y?+$U6)x%Ni zr`TS2??-|O>)Rpdsc?Pp6uy2nI;il<)7ZtVsdg0LZl!F;mU7{!PLlsbkEguIDStWg zMbvCxhmb0PqVTj(Hd-?8RK9_51WAQRuUBMl6h+u77K_L)mY8(MuH2&n=*$QopM6ANEDOty1an@&QNPxl(IG& z*{?SriKEf3P_~gfx;g41+51}f$oe?SvKMm2V$Rr@E=oDbS40JcM{>GgXMFG$g{>Ze zw3f2um8QXXo2aNJ+-NMVgX==5h$}KOvmB04|6njnvnOvRIkC 
z@}SsYL1;|mZGW4n)Js%|w#uza-MrXXDLuT}?(S1N=IBLFPVRdQTsfzzraM;unyEH< zID;q;<(3BR!8&klBcsvkw$>;b8VugC89R>KqQo<4A5k}I*|3RAaokwWs@64pg(O|T z8YOgM@CUOTOr)Wpl>WYTCvM;Cr{~yT1zq~l_Y^*ZEL)5I@tA>U!6yn!PIqr+RV#4q zHiyn&@~!(bOK_~wzmeUpOhfexX%|1BXm+r)M@&8DLj|?AtEGl0p>EFAy+#K8hNi2j z7h{X@dJOr-B7FTJ+7r|$B$(heDF0l!E#B=VG3N*jNuy7rtkeQnxW7dRE%s}<*x{8@ zlA$(G7zz}RPRkH#=vpkMTSu7p;dSNh=mvK~pSRyeZhwLDRp^cq7zc~L3Q242b@D{V zy1HQxCkzfN-+-(#xmSgCSG6`Wz_ZJusMjB zHOLEcXtZ9+{tZl`PKd~)$_$wTLXz}tj^VGGjE<2;L%l6xAVR5=&Y;b0rYr!0R7uMxy6=Pk}P~fwE#?;ISxyT00Vh_Y=v+R{45vS=Q=o+E`1$p+hzm z7AL^ux2z|fj+4*P4tnjR*etmG2%HzC#eQaA8nQVnZDbx&%KjqUBES##8h%lp#iBRC zh`*rp3HD#AYU{P)T+>D$Rm;ID*yQf@J` zszB1s31_3d)?M7anZNKcE#pX7Vwu&db)7wqE%=(F)U3l2GL^^Kjbv=Tfue{ZiKr81 z$|RbX98I|P%R?&8Yn}X-)~Z8?nCnjLTglet%2_J+E|nxZOFsvO;|-G^Mzsk)Y;h;t z&E9XynhCk-bD4J@awW`vD5w(j8bs~C8|e>^-I08VjNMkYt0WMe(qV6@(C=H?4XuRE zJ5t;|+Ov(=QM3f0`8HHXHW)Ff!3eTJBIkT2==^mt>%^J{>5`s1#AmGYTVn7f{U2jb zu11!qbdlvXym`&#SV5&8d}8J!6QUp91Dtduc-@b#$RBzQ^R}G%zUG_m>+?*EMs^9S z(s8_0)Qrws<6h4%Aet@U0X7(VKQTHUuWtm$!}}(pRDMxG+`}N89-IB#OYiL({b82f zqOo*xVG;+R!+i-wD_V=Gl&J6J`kkQ2r=L3YtBn!SBw ztn40qKrN^s(G0Nub%w{xZSwDmRde5 zFpPBM_-!;@bUsVkM@sSoHmu#Hi!_|HQkK2JX%ghiGBe}ziHa8DWC`c9%wvr-+Dr;J zYI`_qO{jCNGIB9n`d41E1A7DZSr;Frm?g83p^8}VwyR|-R$a>2`++JGy{PV$=2WCZ zxr27_UCQgXjSMdped~~}<7=dZAx4{wR-y*N2?F#92_)9%QZ)Z;6!oR+>c2KwY5-)WFmz85~^M^9E$1`>c7Z7RRM=knj@{zB3_nnW&;@nF->} z)va32%FJ(xOTCyRWNbt=3EiOM_Hbl=wO#I1?OV?W?XShMzbs3$FxH^Y#hWv@maC$s zmSdV{Lka~AcjCl}&89~!J8`|P{3scZe1)Q~(fyF;d?DgH3&kfSA2_ zGn%b0zZNwV%Nv!ClpBr|h`uk(m!(}Pg&iWb4cODqXB37+3IkLfeFyy9om@ylsu+*P zoV7Tvo{UQ=u5$3M`w;a--q@40i+`r5#sNUUs2J|Tkh=GxU9P`r*10RbJR;WC`1|KL zeK=RmJKZ@*n^e^<}r% zj!;a&RvDV8j?)@g139WDEo2I^i62^)htK~R?O@;50!3nIIP$+G{y$bp0=2f~Q$_^lbQ-8N4b+;bSN>C=Ry;^Q z+v%hqwX8d;9vrH}@<1*pK_r~8t2yp_uM1?*F7_&5Q4*^GF(gk#6RIE9E2lIkj+pL4B#fCF6~Bf^zZ?PLr#L4rVzm(*!SzJgE9dt zbfvfq%_9Sy6xae-;1Uk2dMqx$OEqVkz+U%j$O0UkWNgj5#(_)gkCr&)Q-S~>a5a(| zhxxMux*$|Ap7^AM3P?C*B?O?C!Yf-5J#UI=0{Kc)D#DXAio!_L3^id69!{KNd2tkv z-_+WoltI{G-ef;i#a$TI)Vm&B>{~dw6{pxXGOQ#DCW z#cf`e(j_x!N!}AX(MiX%ZwV7k%T(Dn(qB~-x9zs2kU}7x*bM24mZ}yRiCj|BmQ8S= z$KkRWL~St{j<^|1VS!C1h6cEli~`&gOn#Q50)!Yuv8nmNd%V{{#wLxT$pY|p$VL8D z2xb{4Z&=L-1T&g#HR$Pdb*x10fA9jHLfuV*2_-R{FS?3F2X6(A7DCNlDW8sMCYasDen-2h=8z1c&Q z_LBVJH*ru3%>|UttEf3KB=*~O)O)_H4{8ilQUtmSyzy{A(>+@nA;L#{+sU*5RZ4F_ zE-ig+si@;+nVpvjNKD|jUr~5?43CqYD-4F19r%IPr(nM_=@m1s?J|^c{D#6}-RTuUv)P@Gi}jLp*95zO%6r;E_}2Q>tZy{1bFeH3TQ5M<6^7D@F~S zgLUbN-egkJB|V(T67N^90t|oB=sp64WrRDf#I_bTE{`J)Dexseh=^)W)i=laO8W#IiX32HNbCl0bXB0 z;`Q|d2?D6DGAZK(ZeVZQ7C%lEI(SA>H=e7ZM5oEUeMePNN>wTnyoI&7PCi8h8;nSm zed}w0lQDs+!OEbnkv(Oi#v6~!8=B4UkohC6+;M$QJt!a?dbg z$<$l~WrSp4q}>qo$r(#V=0*}&LrLpHTBDFJBWRX8D@d$o9z{cgPn6q`5L^{~&K7n* zYm@8EYXnGXcwZ>2I|ZIf#$=nC^g(*Yd%Q7SiVTz-b*rJNLO0A`Ylb5XH4`1T6KFd4 z2UPw3-zb$1SS1M_l@-di#R^j>_wO3U~4~J+#8&i@&icJuyO-b1t*3;G1)X>Hv zA+k0q*Crz!wWG#dVT3BzME6s$)A8Al(k|9f$hNne1IV8|k!3V4*}9v!RLTT5A(o;A z$3eV%z{ zr1bCE2})zZM3E%ZX@^`hVVYF^cZ)DTe-GC_N$&xjclq0nX|T@_xbc$;0|8(%_z^wd zTY6o*$+@`mq%pn8owzIbPq#%=&HsPeVE+#<&;P^2BR~`e>T7UzBYEN@Y~qrh7>}Us zPG8d%7Ix8>?#piiPi@jcE(8hYpAVL8fEbx`pz$sQlHzxs)^Ng&iII%*L;MkowJ(eG z-Km&U0EF=02~Gh+nlE_D@CqrH)K|dvRH~v-W@lC{Aq6CXcf-3Ma$W^uK=2 zonivu>gGbq{AnRW1KZ;boe;#Gl*I-gVDssXCVW*DAK}#r#$F&jIBIX;k~m>B&AD=z zVuxWFv@_K^Nfld$`bJ=DRWKSzhsu0Y;X*3gc2^XT;8bP1(nSEah-xf1K?hGe)kPKNSB*Zru8=G4M5`$t?G5Qhf z2|(cPQ&Mtykc-9m_zYd^b{T{?ms56T;J^FlD3g<`cA2tdB&v_{d<>rZ)6GX231a^D z>9D3@6Ri@q2>Q=ea$7hLTLu++{;#Guxs!tA`zH3mamsWG7!fgnvvv`mXt}3S$>ecQ zj?LZl4EQ1iQ|wqPGOe9kWi(piWK`AJAmnyyZ)FY-3Sdw-o}eiDG2)#wjPjm#4v8E$QgR&tecxP@ 
z{2s5?y_}3IQvhQmQ0>NY@2WX)nr{j5O~~Bl9GhDwM^th-3-X} zNYnL}BP2Ssh(a$7aB+5Sue?*lhay%~%Z&nP-J#nb$;ge7OraDkHX}pC%Fe8CI=0Fn zRCJosIPCEFS?LD*SxbKk`JBqc!{!QCX(`02nQwNa5e5TqD6T0^64r2uC@v(RYY zteL2oz*A^&!ab#32zx&5@Wk%H@p<3IF>_x!O|Y6~d_jRsTtF`l!zms(PdwppHn`Oh zUg|U-;%S@DK{&8I;jB4m$N^CBHB}e+9nQn;7Wxe1ReaVLc@xFD+YLr;w>eqhdRyQM zj9og@c4_LDEc~s!5SF@t?ryGkB^(Te^N?`Bdcrwp8yNfJ++{|_a63Ru;VAcnGwV)> zai8W}5?HlJNO;eg<#QAw zk=hVIS!6uAbTnv_#{BSw1yc`l91oX%LxX+pf3fHS`cj}x^l&TX{3pQTlFg}jQ6+C~ z_lM3ot5cBIe{p&u-|3z*G1?I?`z!Yu5J2^OrRcYrMIL*(2fj=4zfUVU7f@d$b7c`{ zlf>eC-x=i@8CjNhwLvny_ssO$u;$&IM<=Z}?0WYEyog$*8BUSZi(`4@eUdxw^X3x& zJPwav$l|}5D4~Nz94`OSxO`S*P&%Hb~E0>>iu8($ODrB&y-O7+;kWHT2PDna%DA6LFQh0 zYpL&!aITDJRysynT#1`KXGJCxJxD_N3a}a_ z8!5k&At+FnzaqDKYmtLWRXc>Zhtf5K0}i7p^!iS;A8%1_n0pM9j5dMzzn6y@g& zMt2EPhj49@<90{Q(7p}+1+Is#S)J?}i%)w=`d@>yImb;Zr!^K^xBGsqs?!L-jZJk7 zywVOp8V;+ojvo> zxqY8s4;o*7Xs|tC1(J zT{&WSm0>xu8WpNo(Ygy+9l%gM<74WXwG5E{?-M4mpfquq50*XcLqyXzGRDwrV&5Cp6!r6WM=n zXtYR9bYqpDlAuCnq5FCF?CbsR-_Nh;ob!feyfLD^?~MUBv`W^}+08u&g?u)7#~gHj z|ME@U4fhXYMuz37zCGd%jmItQY_ZzMW|XjSkY}v-8$I%|(uehg^NiOMHXpyd`|^B) zL)0^V`lCg?Jr1|?{Pu-c-BGsq!Mds094A05s|Q-Fd)UGf;rbAsV&?D+d^yvzR z%Jz`#w*3*mYwpjGnY%~b&*p3yn)Ua)qjYB_|9up>=TBAyy#xi3%5nl;Qb&z>Cyu9c z_7-M`SbO<2Yog=TFCdE0{oMf)nfsx@?lNG5DLs(SxX_Q;{#MeU0BB;c1k#`dhF^Yk z)oIGJ{Dex}q!z5XS0$-SnnxA~8ik4oeQ^iI#?C9@6f2V%jD1729NTe~OIp+fMi03yY)BGWKaoqXZys>$*Kd&&^= z;C0v5*F$;9#terKm3TnDmp6LAXE}Lf0^OAZ9{zWSiVX%XCpwO_;a9=MbJ12 zW|h*zzt5~5IwkIFs>0+2n%_Il2~x*K)~ohx*9m|7-pdbY72n)q^o3T_%wPZ3;^h;Q z-E&nM1K`qpO_;6U`TYjT>C`ksa3HE%il(ft`(<;-yYV5jiWYyk?Iq7LMEv^8a!v4w z*j;Bk-I;I@SoTS@W&UzB8O)WzdU+Y+K(JW2DvHyZyc$P|)!a8+nIHq4-HtcDVr4D6 zi4?gVoNC@!{s&RTj^r}|Rq8vgk)PR$sx${pEy%=@Lz+UHGe!p{X&L}!b^ zr__IaHF_M1GktMCMFJbR=E_g76mk1eI$POND8SwadzQFwxrj5c%3WWs$^0vp`wu12 zoqC>kqS1mcMMbM$APacimmeU2gk9J__8Sqy!kv9t&L1g(27MLoKPHrZ-c*UXKMmqN zuJJpKrV-w3>36|3KWqSmy4`6y;c1wVP`h`MNVRZhuQJ@VYY6P5>u)#8-cq-vuSF}H zTeO_lPbgm!wxDaXiPjfPvr<;!N^1glmP@*WtdEua(*}pKogvF+UADCs%xNxXV3!mz z2!e&UX!CKQ?z*CLOpTqQM@MV9%*0OTx_|9+Bo zKU0iVLW+~E|8!3wdQm%YIcT)U{t5STmtjv-lR-*OtHt*vEpNPCzmJp>rmChu~ta)lA+1gvS z_1Vx;dhF_8LG~^0q1x)?jlm}bFB=S(SuE#T{EycLN4+(Rj*=DTuKu))eF)1dCKf`~H7JR9zXgZbdr@kQ34Ty0M`pd{zzs-|)UqD}c_kZaCkbl8qZi&e{c=uJcP?^E3Z8PpZ zpOtwqIDDP`vmD6#deWp$o2Vh7pjX~!y#pt5T_M7%P>=5q04szRsN{o7I4))wCieHd zI_1=4Q%ZuTrBLKSP&1N)n|GN;9H6|fkMoCZF{SDXB*)+8W3Aiv^jjQAEy>} z^|4QsrxYeUfJCs$Hmn6;gM2_-x_06}HnKJJ2MubfTAYB+^k7HGv-hpm;E1v|V;7sF zgN*B;L%3q}vYcc5^g-YK8!!lEPaqBBi6`pFgRnWck5eLGp5TY#?R z@vwQk2beVYR%ywPV=3qBIj=cBPyY+gMGPu3svs5aK>{FOBrSPGKEMnP4h}~&qEfnL z4WN%bEfhE!soV~*nR_Hqg{-arw5_{mzQ<`n+>WdITI

_ONvW@)j2FYIdMU(j`H~ zs>NHw`eHYjB& zQuG??(m4D4T6CHXj^qA2JZ2#thvIfMgw9Y?zd@L8&KT-yabTCTnlF0H6di`Mq-Rj< z1tE(K5c2kqOS~gHak2lQ2ldpk6-i+Uy-Qj7npZ=DJ0#-lmfTj4^qhmD8%&rvT&NfR zBfRbcfW4h}4v6O}Jq5ux+o8CTu4C(9YzCPVwl|IFn(PmnPg=W%iQ~n^a}e{W<5o+D z>FIe7>O^aB6mZQML~(&lQ>6K;%^FZn^SHd1OmO5Bx^WdtRgyw_wj}m(OQPVy+y{Lm zy`kxbMO`3VE99J_4;3Phjy`ZcM82#*$qaGEd?i%*DAmW(J?BT`x;bLw-dbFrB@^rp zsb&^&`d8IG7N6i4t27jsayD?b%=MJViC<_F%@MnEO^MBngFI`pdo=Zfj#^h77vE;< zFN#m}I72An2>0i$uRJEg`ui@oI_8*4B1FQ~4Ia7aH9&SX=Nyy!Wh-p)rczGOiNRWl zSWqMu3x0j=j~->aK?8_guJbH9#y5@VTHyYIYHP=991fcvROif;UG-&cW(oMrFs8%) zHdI%&&9N`}wa3ZZo$^;TCAJ@nmDS$#6BQ2hGZKu~9(&Mnw}2@M*y8(FTg`(L)WMy#|Iqr=8QH7xXRnnTDuuFU zWUvE{a3DwxHnz74Wv=+=Q8NIY0v_5X`WINzDL>>~g`Uy2y5=)j$HV8oqfg4TzK!1* zv9>_5s6z_lDB>zPD)jr&{TYl78HWn~+vxmZboY+y{abNaY~l9(aLiY+Y>TckqpH8x zT7>n)$|8kp>A9;}u5~Qmyh_e;K(f{1yYC+LJ66frco*w!9xEfS8ik6f&8Ba%h8A7q zXPqOiSg>2x-yHY16f5d=oBg~!;@~#JosA8tf~v7u%d-ZgCV_+=9>Zd0t3)=H$dd7f zN6v+6rIJ#M901vdma~p?JZ9aaI_r*7cfL$e2^UN5deFzFl9a3cDdM8gLiKT5n{H&p z8j<|h4*s*F)=jlgXxT%CKQA505AV2h0&W-(3$%AeT6Im82BB}FA4jn2S&z)h1AMkB zI8L9V@z2Z}xoXbNOm;k8XYpFX!{y)qbWUSAtT_OiyE*?tRvH=JMtSPS&2`P?ffmNX zB3ZAosR@b{4EuuFY`-4&adR+tDLTv8+9+-L)Ylh%brZrLO>UNH`2aARR{ZRgV1NF>SBTc3mZ?{8vN8)d5l9<&loO|3bi1jc(EJc4+RBij~oZji8#_Trc%%WUpXc+2~HkCP|h zKH!k!G5Gc3jL}_steq(v9H?&-O8q|f=-PbXTp#9XWDiQl|KIN|Pp9>#DKJfeX$nkJ zV44Ec6qu&KGzF$9FinAJ3QSXAngY`ln5Muq1*R!5O@V0&OjBT*0@DrYSH@ zfoTf-A5Vc2s224O;$GtI83p(vltZ-M8d7&2!4dP}Bl9(=I!op@tYb4WtG1l>(~o$D ztGx<`v1_|G7iC)evt&;;GR#Fq>DbjOwyK@g*W~%rrQMt{z2$0tgv2u1QCW(aJ1xr? z?_$I*{Bj+0p~q@RfdpJs)1kYNT<_%X*&F`tskE^)J-4D3(ws3}Evr82`bwZ6J8tKUi=%|p5*$E7mn!U9i8_&}&Q;&ipt2vMz}sWsPtb#Rwt z|Cyz3S{d&+RotU!Xd1CzSwR*Z^NjVnV-p3bD&q8e#HSU~-N9King>REeF}04yhB6W zv%acSsVtm7J$b_8PypU%F^k|5?;7mzSBB-B8)A;#*lrle;Uy=p2+Jz)FZPHWBViDQ zXPy&j-m-OBPpWJO@&$|DK`cp@k5xXkz8krXPU}_uAH3xrp$sm#T;1MkVSb2N{!Zum zu^Gaq_1hyOOQe!;UWK4GW$f+ao`H>r;PP7ULYeuO*9xF+QtdM)E3V3wk65+Jf<{YL zeWC1ilV^brJLuR?^t{6>u4*hpI>)Wh1=c#eWpo!@ew4MFnxMLK)@bKGhpcZO*qd zlh!TX72Cqu`pAQ3!G8z41y!4_we~l;hBHDc62p%;EKs`FDe5)M{46a9{ovd$9)$5} zRKeB7u~pkL(?u%pF9d;Mg8jp|yvE>y`$6Ln$F+J_&q|;1&Um2{W?M= z=jfChLbBQF!yZ2aNDs4mEH8Nkik4?P_7;DfpjuS#%ML2=Z!F*j9QG&*aP?hy*`U5D ze?8#mdV?&-;z#cc@l*3pMTX+ z%&^$@4Y7zu-|*7}ikUtKu;b5Zv^NZhYh*?|?)@=AFxpYKQD-TY!fn?rE!Ir9ux+dD zN=r}pymyZJ(e79-|ME=1L!U!&#-G0w`B2b#B4V}-b=|fIbv{+mqx&5#jfPfN)pe6M zd%lM4EpE}&o!)Qhjdcm#W16GpPhR`^8<8_QaW~P0IF-z=KE2!08)qzp`Du2`21fGc zFlFu$*+!Fnm3zdb0%7rCZ9p)pe@ zhZ{7bIQphnRdh`e%d{&KlNa+J8MWrEVAZ~_V_!f&FPB{M(_mdppGJxj;y{M?cP@xbpM{YrAMw7GlLJT&fY)D78f~J z;VplNeLN7qea&!Kr?NI$?3)0ex0yqhn0x4)rGW&NkIx_)YQhRE-DO<0IgYq(eJVl( zh*o6ub%B1DV}Zpvd}i?mwodc=>ZJhkyAd`1)sSnLtcEi_?CW63XI}@;yH5HURqRD>Hfd2`lSbw2eX6 zac1@b78@#N>58@`BUN&*B~XaacS1FrS_D;huG{n|r##1ymW-7j6|*@V3TJI$e)#k6=v`skOZGHhwrP44y<7Jq&nILbBZv7z zi}!d{zH#ozhFV01Pm)KzNjC3yT=JHLhasbvoZm5HSlPLaD(B_Gx>dTh_o1=|%uswy zS!n{xMFgVHG`ERReu2!vWIwBtZ;#|yl1Cm-flDE@r=;Vw>BMf4g>c3xb!HffDof?N zNA11*{aASR_pP0!SeAb52B)je{!5jmW@I06-I}BiptFXEYg2LN&!MQ7SE%ZFxCZZB zWn~aL7B|#iwpqQAwe-DdJsFnKK6D+XO9w|+FgQ`?v3v6oC*n+K)m8)eV61ixAwQndzoAR zaoFePkShpZAPquv@-jnHfvF(Dr?w;Jbgav+J-m4gE{H;8?&c#3cN$$L7S}q~va0Sh z^tSbPh?tvC#rfQnU|fg4m`hYTcOjYP+uH=mt@_fkHWeLN7kwO_w+Ajk^sJ&8S!Hd# zmz4+Mmpi(^o47a2m@SEaN$|#>13>Y8JLpoM5@;2!1huxq`;{!j(ceopFO{+_oFOGs z;%?q9t~()#8QxDMi(kBPcFyntBGi>1VUYx@JD6FSqpVDC9mBouYF^weo5-alzSFt& z10S#9xf(d5E-jr=KBs(;E|WAPAF|4;PBXihJnI2LT79oUegv&&)X)ZS>zsl`E+^A@ z3u$j3aGs#g%Zu-rwGRKKJG0;&JXrbrgp1It(iPC8~$>{#wZ!g54I0qcozrr_BG)HqnT}|8q&FODpy>q^)KAHCzv0 zV*VKUldQSCO^+Es;f~;$vlEGA<moeN8(pNdG$I)9bn=>42P{5vwz0u40cg@>X@Y2D1f7 
zmD%BC6V19{Cm6yKRN2GKPt=XUP09TObRDA_ojbIHIOA_=RyXA~l_@xOmyKEFXsC3` z?eop)VI97`=AEC}5UA`MosA!YUjteWD!5OEE6SU6?mNlhGtDJQrQMt>=`Q)gZ+LF< z9?E{L%LS(w!He#WVGCG`-qgFVWNCA_Hql`iFyi~B-X>bc;@6-rHnXo>$MCoV*pRK8 z^SI7`N56|m9?dafdyI@DBjxT561O-ZYTvMtAveQ+EEJ9AEj;IxpG_3dM;mRM10*xT zX4yNC5{ZU$rgm;8uJwnCoh@*KLXcYVmdos|F*4=^x);0q2y?1n?{``sI@zn#$f&7N zA$2ZtnwMN=sI8xH(Wr&>nr9PTwrjx;{CS6be1vj_iBa%Sw5%S!8~=M@687)_adw48 zGr~49*t^GuEKAaS;mSN6sFyYqw|g+Y>o`win%z{nof-xeWu zG1^C*eFBB@xYauHQ5-zdHmnKCxcg^Z;LHFNrWLmx7EaQ`$^2(;i%cS{Mnp2g=Hr;u zo&Z;6vxwm#7#LcOGDF0OU6u+gwfQ(l&5?ykhyg8KrsqPj>?<%KuXY#t>Wk2Ct)}}X zqmJxivmhJpNsd8o%}oNn5=-{QUp%?s)ZJw)a;_(4rU`}&C+V^t&g1b8r^|pZl^m0$!!^M%GWRs=A*aCnaR*ZfnCjiAr`tO>E~?kzS+ac)3)10iiEhK zXqS*rvIP;2q#?Y*g-IiEKNuM-sQsI6WMDyRI&yc0Tsz$DK_4eB*}m`|RZWbldJZ3N z_-Wm=mNj=*O$wKR-R&p{I-0XBKsHWA9ePWm=4`n~#bDg~N35uoNJb!DIo> zyw@d1!!bA54x31R;limfy<0E>j5&v7OgY@exo4dAc^W9ON=Z>}N<}D00dT=m?R7jV~I*PIh*81)-I-EI!A+__ri66}Yw$R!dKIRREK{ z%MIkOqL)cFG#MEu124iAFec1dY^Xm>{_6URQ{GdB7UEumZFkp*aN-!gl*LDWp}w)>5`Ad{A_2~g&ZfU)6G@Vgs!+IL)qSwXHJm+mdk4uP?o4DqQk6MQJdK*9 zm!xPLG_9+IrsokZSe&n&^0&tuzdM;BJTB!B3mI1!<3FeWf@EG|0Kbec4ZJxBa%^5; zWTfkfWP)3(3E=$1+gg%t65TCQn(EL(ri3oOCtL7et(Y#xhr?LFF(1J=fJthwGSF9Y zvlDK7#n^>C;|y-IiE6!X8h&=;uJD2w9%YcCd{0AZFCiCWW#c-~Rp~5q%8$v%brC@% zW8=frK?BRUm%HrcS>h@$L8>zC6E7k6sP%C!cOQcqxZ4=3Vn#ev_S^lAP z;4iXO)cXry2l7=cn>M~Qa(oomUUSl_|J8hu`d;&(BT|qiZ%x5nB@ROrk@Jn)PB!Xl zu(^H&O6`LyeqFb9zUDeSx9mWslGs(u9=qns~d;+k4Nq* z6<4|wgieon?+qYzc?x`ad0e7&&)P)mgY4KMkjwq>>$on=(~kd2>gy`FF{rP>m9v&{ zXk5TeV?j_a-8A#gKe^1H$za6(#AF~Rx5wpVyoVqfq;`%Mbd&tlg^jp=Am@^bRN2QR zruz6}eXM6yv3EE30wb}So$mSN+k_!R7eBGpr`@zEYBIgLE{O~xdfrS}^ei|(QrfH(F zeAl?r7+2m|}rrmzD4qRo;IGG zfMG!runM^_aF5QVe_t3d0+xFUtpEp!wlvsM$8|tK52xj>b%4JDx^LdKZi=5~n|JqD zvWNF=a8(VxAoiZ54025j{l72EXo9)?ZWqu$$iv-fRdxLHuUpV<=m#oE0i86U^WgJ) zQUZUCE9a_P@a{eBas)=J$>fJ^#nDC&mwg~-(AXV$;eHloT{0j|GRk0B+q9X?lvd$Safm5;|Q2~~k_X{wW zt_|Vt;W`UM3Eu%f0VJ!1w7HKTkX`sQ`P6IB;H}BAARu-h6{yLFVmIjKx4W%G@>qkd z+<-y{SI7?TPR0OytaW*z`@-id@wi|e+O7d(xav;cxT_9su{}tw^0^4KXm;5$!6$^~ z3$FZs?7ekZQ}6#UZsS`71PLioq-!A3WzgM7j_wvlGZhtSM~xJvyO9_$=|;L4O82B; zW7~7k_viaO-@kwV{Qme{*SU6W=iKLY->?}t_0B)ty(|rQCAk;E`MCj^8P3FSK$NeC{5U(y! 
z9$i>d+XdzQRnRB^dE^=JhF{`W0HD=SH2mMMz%nk>Qv ze*nOTo?I}cZvYKS2Hs*!{1-%}AFbDapiiEwHUoq}nO`7AC;<@j+;;%}Sb}a@1t62> z06-?j_P?M3APZ}Ew^xIF0o$~+3M}$Za3%u0R_KNdk_6CK%Q+w{-gm`k9gqRD{ekz# zAo%nHFRUa2f~5h#*7Gd@_IxFY@Z8Po!UQ%sJZr85*guI=ZT7pn`&ZQufDR;)2jJ1~ zK`%m9KpTyLEW`4=<-${D5dZx5LgQ_JKH$73VDCi`$R`8;&NBaF`LE!2ozwUO#sMI` z;eX}`{&f)*{|D@aZ!C@^!)gv=oes|rG32qtltwV_Gh>N9w2S-g-eN) zsYf#WC`kh3@;?i-8LA4jsQ4>SH8OD-rhxHHw89fNUFd5l`p>me+`opb;}kIJs~uf` zmI-`4yAY>{J$67lFU+ImIbh$YnZ~$H>0qVF1*j=^# zKw#cdLh|`Vm;n&R3wkCAy$Hkv@PP4!T*T&)GQgig{E5B)oI8GO|1Sf9I<8F1>Y3$5 zaqtAbMTk0a3O@FaUAvGRz&>HK&SJ_S|1-hq;eTcjlV}Oyi2J}Tl>pzGtyRHy}?${5(6*D*`oa|4qRHtKh79)^@ob3UzXVo zAwUDrs2Ry0azV}g--iyL{$h^THZr=xOF8rWA03QWw`B6kfGP$jGdva0_sx3ir_N%X z6AoX8<__(wtj=z#;pv0k#(?5p>p-zxE1qX{B4u0?B7G|*>}?n`Ubslv0Gc51+7=8- zzO-ArD%Wx1a*Pl4v#1CfjIY2PBwS^A9aH?(4LJ}Iz7uqej{vR&^u+I-uGuX-z0%!< zNwwW&yhVbYRHY?@1|O5e2NLFNcy|ncuzYP9C6>OQ*e9Xkd}{Ucra3(3=<|`PpM4!x z0K_w)Vc30oa7;XMDB(1;5(1_B`47osl9$crXAYM>T{1f;MfGn!|N1QN@_<7dQwY9~ zr5ykLa!L?66XhjQ%1h!;dqno0U4H%I??JGkcnSUKetezu`X$oquSw(F+&M1&d#oc> zuWXSq3^srNdI>m&EcQwMi~t421n~SXlK)>gNjUof z?ML2ZAxuK?2ZR(Th3p>*@xr4C=WEda`F%ln|9h7Y!5`WHMi?Ofio?*a4vqgSQUCY) ze;M^(mHV$<`LFT(uZQ@r8~p!`uTQGPH&2kHg?(RL39gq!n4p@^lYAVV7w%VrOh9Oq z@M=1LLCDwG&p|tC-qYIycC9A|JNy0>jo_dgDG%Mcx2m9hp(iDN?27y5(6<9vuBBIL zcUAPEY1E@=2&1=Ot-prXSO7%V<^)Fk)9T+#S5AdVEPfqYJ1le@4Q%f8HR@Oq@bN*v z9R08r@%^p5%l_U~jSVlQOP}K&Gkq$h?ps+>ygK(l zc$I-iUTgHS#{Q`sToV#);MvPGI`8=2X-#+7RS`Ldeb9 z2=3RAJ-nsg0RzXx(8niytCT$I%-i3r2+2;zL21#HuA)F>#czQ#O&6*c%&@5p_$Grq z>PGKmobj(gTu>x;WbX0zOiv!q>B0x9&iJvXdn1kT3SVdZgsaUYE+~*eiD5hDw}bk2 zhM&}H>mzn04dhZ#m&<~LM(04Zk)4%wu&Wu_15+I5$dx)lr0Xje=|1-vpN*bwY1jTx zF@khI;xYwMszAWP&tq5VZH>H{BTIH*KV&F_^gw8`C%QOJ6zV!ioPK%Z++p*d)86=Y z+qx7||Jb{f)70MqY_-G+%4egksHT&HnS@YGadEd0{;zB-uzDUmiKxz zJ(?_9WSbN;22=MmIGiDu{QR!-Ve^m!11AAOHnqed4!AI{d=8~Bqo znBO|Ks*Fa*nDk_|4X4Aan}$r%xQ*3wIW(4KXBZB_Q60;z-{<*meN_8^^R@4SUui8m zoKdK!M65#aph2c*QzH&y9<*U@c=*E~D?Jvxb1)OA=}yCKd2U{{+LjjCmUUJRns=Ox zNe^5|&G6u8*xqev)N(fz6KvhBmNv!~gbe+tt*%zr(^bfX*N}7FZ%t0HWmFBBV&xl~Bth zkV0=b(E)6+!rp51aJ>a{P-B)cbd+w!Vh2XOn;SbCY`$skzW5GJTw%frDZgB-v3Wo7 zaT_Do1-cJ8U%y0g4N3-&p0{SM4lyMA4<=Xf zXB;gLv5P%}#od_ux`_MVC|~@dZ<)OwdX`iGQKEByv4}$|k=!f$OjKGRGous?|F)q? zU>$Y`vXEcx=QM)hRFA%o4{UyWLg#W>|MRNNimU=gciS8X^YENG_vFW&)5PgWrNRlj z`XX}gWNIfZ_0{ou-Z<9BvF)T28IAT49`KK(Um99vo)jtTL+LSRpiUBA@y(->ajl(W ze-X48qY(4!sTRy5*PWllH?1nlO4E~kFtNpW;ll%PKO=cZO=aK8X0yq>BRhot05v@= zC-hIC5R?B$#$uTfo!rki{cEJ0e|b*M9K{`iRU^6O*H&<&m7;SA=@5OTVAj0@r65p7 zvcPZK=ocBw*W}bOp!uaB?=N?0eKmNtzy0#PI1?IvqOjcF>UzzMhC#SBdxmubnM9s{ zQ`E2#>BNL4Vt0Qbsy`*wt)wp^MzSqi@F5FG-*m8`21f0$&%kdT`o`jFKS=VHHsY)tLsPpZ(DMg$HsO{F^V96v3F1|B7fEV{akgvr$j zc9KpoZE)&<-kmtizB&n{X_zR@_(1TqQnurmsOz@#SsjNbcTE%fE)%bjRgSG=$5}A7 znCT}UlRex9+xYo+nJ_Hu9G-1ZG5tmvW_D`Gta~G~2##j-_jR`1$ouG+?y&dsH=FrH z`-7FfsfMK2)?IRS?FriMoBU%?B_ReNM^2i73VDaOcqeLHJXfp;8o1k0*;fLqT=cehHp_#(@+;x#?h#!j8_X; z5^cpW9O|CaO#j^1=cA1Nq}Ne|B;m!;9AwI)v`W64ix<|g6r;3`h=&;UN*&FRt8#P? 
z$s-M*^hOvE9v>G3u_&oM9eFcl6qZETa}jv+BRoG2`P)kfQjYJg^P~d`69S~BPLmpK z(0}Ikl7k=yHT`A>>ez{-ZDN|$LZwdqV5j)>s{xN}JL^cG>v3scx=9EtFZb}iX+~2n zE#wh%hv5DQx}Mi@$Wc}lQ;ZHVf;@XRGWB}iOwI=K-BMG6;)cP_a4!?Avr(dmdq#4# zLzH2Q<;0KY<$B_28xC&&$J&k_f}25Rz0MC6CVv=ZolZ_y#zH^bi2wb``?Yk^L%-XV zg#F-^xrByEVl(}QBcoY5GeP9HFMoU9!8$zm0#{Ict20gP|@vzgA4DW z_Wefw5+%z0HIL{xA8>YlrS~qngvyoPO_r{3;UU9=?fY9S9c%%^#-TSbpuf77w}Bix zgP~{OM`;}3Yw_+VKLFujAEVm(jKdF#q0jItj7<&ZM43hs8xZYZ5b>wWJ>AgiAMvh1 z68J=NZMmWat;$02flarT2ZWsH53aM|qJPZU8IM{U9rsTpHhs(zXWP>3t1!m=-Oz$6 zINB`PKuqZJ`R2?T2Em%Qf^=q5NE2ts?2!O8?rJkH^=Csff`SL8+4eqI(>XnOQHcy`Pip#OdGDZIj2ih613g^giLn{a_toK@kf(y>0JEE)nu2k zt9PBTs$1%!fe*y6CwD&7_aaI{AlUGwRhNO)u(@hDIJ10KK(ow+7KOt>&f-SH*PHKP& zaKn^fkCgiO{zY7ODQo@b2*r`l%~Fn@K_Wj^7rJ^`y+!#8l7%b+IK8={mAwK5uXfpN3f=0krjvI6^_tw&%yp_y5pN@v#5k|_| zkN&tA!98%cN++MdMiTj$DaL#&fp0(*(HP`%tz#!^`Mn)YZ3iD@!KZIUyOi0~hR|Iu z%ob@r2OsoVXNm@>XShr|@1j;2=Kg|(>>Jki$&dNE1|L@k>Sm_Hwj%7}t=||nCwf10 z#<$TB(2C3NlG;N^b74IcIoVVTK|w zN-B3npt^6@WET?qjKeAsc4tvVU0gN(G~^iz3W0fP)<*&NCZRA)jOp(QXUo@w{Uc}b zwT=OmM<;V7)wLN5lv$^{(oS9*nsBa5PD_n95eZ{9CvqJN8n&;EnkSUuW5-+iePTjkOT-n484qRF16+U&4Reo1OvDEBEh5)nQfy{o+erRIN3R1=M zwB-^WKM&jb6H%`ILfY4ToTua>+q@M+8y4bX5c`ED4Ke+rS+V04pYEi;X;-e?NF6Vs z@JB~%8-F1obKF^qEa#0g?Vae-un^pwF)s`k9r=va_-^BS?s->LdD>9wkmcIlleA$> zZ-KU?hfaQidUwbYGssYooRU~d!femE?sitqg3rL%W9s3Ny#FZ8W?G4x!hF7vp8KvV zWqClQM4>*iw5!}~bfJ++0fRkvlqg~@#$>Qz+qq9KxY_=(J4 z!*Z;B3|&nisequOMTMs_ zd|T0vd#wvwEbF?x7`4YBxH~6U3d;u`8^F;tZHU;~^*o(O<0#jU6-8zEsh!oW|JE$T$JXij{)~IV4KoNL+WCIB(eS5tb?+V% z+}Pey9O)w$^sWtb6PvTn&{oV-j{rtXm%uu&I23GCT!IC@INCJhYbF zIy-;F=w?#LO@U=DsTZC}jQ(=DpDlKIMR&Ka<5g|TN_pf4^1PeVh%C#A*$^_t7jwVZ zk9F3Sowm#~T-FT5TZM_0g)B&roX=AkJfUWuOfxcK)!1nGv`NbPcH#I^%AP5%_+fDs zHNYBGLXL|9$wCH&R1%-(tAdnQV3urxss^#wn7w7nM!srP17s|yej-;js>})bz%xF< z@ble)g}O2kpTu!73t{0=C#^q@xf;1KBPf&NdoAXk)Ku=h7rZ>g-_{!pY&kxW7E`&z zC2^hjfVs-@D8J`;v-L#Bi!9g#xhAT$I#O$I&zNUIaEFKBYKhLoxSbp}s|9`l(UB_Z z2c5b1oPVgiXubb%WmM5VQP>dL&|u_Tlp!RWDZdrRb5DPnTzQVeQv<-E(XC>JhHIOY zd;d3uNW|PI)PdYxR-ISa{KG9o7;|2OIeza8h27XpP8oq&C^bsjVn#ljYx}bdqs%s6 z3 zpio&W%uGw_xOC?(?FX z%vrF4`Iz&f{atal)blt`0(^a1-{GfT*OQ@0dpm+zl_(ON`aXWw1k9(gZbPA^CrI8% zrNo;tSwxR}1EsSC=`)p`6!x2^xoP2~I}gDljCxFzOrMKo$#Q?K9DC#9yvOWHqPHr= zS8CfNw2{?T_(or{k*;PMtw5R=xayj42pihC2((un|eL)hj?3Iky~I} z2LW~ed!tYRUTZTyatR5^&j3_-GG3|-8C<(%03>XoLtNv=U!nW-l`CiEBIp>1@5mqh2U_)>owqL#ZysILw7^z|e{dPKW^YAxj&suv|kw#;z zG$R<_C(p%|65tfTs=(40R8iG0ygs+KeayQ>Ia`zb`lSprlY(vL2%GrJg{DC2fba7K z$sSEcGFu-R5Lu-wN{P35{;FaIVMH9o{(|%66r=uf3+pjueI8+i8KolDwuibz4Y_;a z-OiVuvATX!(IChz1y%dmyjE_A)1Ml@U3#^-+$rZlN$ai*3BnIO{Hy*p!VPV!(ejeU z%_o*BjM=UmL&?~N!}pa_h|v8BZ@)vQ3k?Hl>EYhpL$6mJ`IoOXj=~+5)yWJ<^C<51 zMkpb=`W$^fiOw)vh?<_*s1jC;bz0VUZ_Vtq!Mn1E`lK6+zj0xQ+o)bdMjD*C>h-P; zVyMn6BYd@U-b|W~aLR>_c?K*n(dyDbZpIg#*nUb3w`Y4Rw(9Es>b;I9r)i8+VhNEh zDi;$JxrR)Nb{r5*wO?iWeE9^OFX5*rUodZF^YwmS)@H#xWkUTp&usJU5(A;5Ym)3- zXltIs_-3J~#Gi$wU+Wji6@6xOkuiJc@bprsbj)&453S*$`E+Q6qm$Ar8O!84^XTul4;{=s^y)ZP7rB^*>9>&C8X(fT_BOS3jNd7(Dfrn8lJ zFi+QX4%c9evQ+s|Gm|C4VRPFZ+)rL8KZzYJ)_qoZh?mS+2-~H&Z~VN-a%zXiW5GCxw`ig3CB6?{|EV?-pECB$TIRK^f_EYt zf>Cc1RK9%7!KkNfHHnm!GwvT*w_RKj)AD1d?+h0;nRn5#9`Oe)w~7qb6a_O`oq-l? 
z4`wmXjcAhjm*@cbF@jKu9lI@Kmgf%>Y6jV?G|%1m!4+B9QK-;V z(zMxZL&7zA6hI#?%*pYm5J`SzrD)gl=nLh=wY z#NYs`*P4{<`w;XdACw$2|3o5>fq8}fu7yZ&ENKCAss*jS~IzifJdQ5~<(6>q4tTSud>>skRS)!=uxn#cZ4Q4f}WwI=xKh%lB&i;B=KP zg;RE<7}AsR%vF@rb`BV~m3O94;H$lZ|5mIKVe~M%@)fzPwKeaR#G?Ch9wA{r^~UM^ zsfv0F!*S!=f*ZOxvkw@MtAc4f{HSrg|4ZZXY&KLOf#wZZT3&(5Ws4VkLodL)MUDc^ z%z2!a-}nlmB~)f+Sfx@5FL8bYDGUaM;NV@2*KBTR4iYyx=VZeqMvd8^;k`X+mFG6* z%q-~=zmVGlC3${y_sZ_b7~PXt8usb^rJg$1$z#DWRNcFY?ig!ms62jD_QEc<*i?q` zoTffQr!2$Hq&*;<)S#tU$SMiD*kgQ@zS#gC=DhRgt5&!JI=XD8E;9BWwJH~K zsMK(q_pjq@fzpbgVkjX4#9zTvHGF6)i`v2rkG9<*<`Zs@#%{)QxcX&<&*PL4Cm$_D z43{f4;(5~Ro0=yp+DUJ>H$-NJ#`eFo>%=Y zl_{x)1R>``&B&mM8o4WK>FT=KOizGe=O@vmh$oJ)T4f6Q{H61()A&KRoUAoP)gjZK3jyGcCrjo%P`y46gbRY*PHv%4ckA@>Iit`@Ng0#{GUd z@y*u8PjnUvnQA@y<=yG zmV4Xi+2CDB&YR9#h#uymE8)*SoWo05qlGUDprX8l4!V2kG4LP5=c-E@&X+U*t}#Wq z9&%1HUm??1R@*w(^HRz#93Ae1R>)gP6lW|n%M*R=^sc=T7E-U6#!#ysLVN5=XVvSR;v?l#p#-E08{ceoFJ&Umz%)3CeztSWI> zywLT`l0xsFD)lIRhOb$UhI5q9>FU!Eq6R_(VJS4>jZ5*`D35+{GGRvkIehbE6j+64 zz}ga2Lwy9S@ak(4d93J5k_|F8q4{0|Tplx{BxK5<$3T!TpaGQ2xVBUV)R2)NHY@AYb09mT9gFh%lppwWbEqNCq10s>E2RxNjN2 zlz9+u-n}n6Zvf-TI*csykw4?Lx;O&vX zIAsQqXD@Vv7M-*G>wCbLa_-UDQ4+)kGj!$hq9pyL@XelbYKcSc3biUzZG41Kk|+~d zP<0mVZ$>BEF0(s(dS6YnzjOTNz!kLlQ9v=&5x8A@2)eq`rOlufy9yc9%Qs zfh5~Ec4CGR`V>S(59{lO5xI}?b{i^hDQQK`R@^kEgElshFE?J-A8UIu;e~bxHt~Bf zK-Z#GR^OApRVUSEW|H&FW)|A0yinKa|vLR**rl zjD<4aH3>N=Ew|QJu;FR@VA8wj zu3Ly2hvBf$`4+idgGFDM#3_jJO7!j!P4o%TEyR9R9z7&Z?%UZJr5j(CcxW)|x8Kz& zLi<{$DoW;X?{fsQl+Z{9}QKC(#bUki+cs> z3=`N(Mikt{#-~M9v72NUkPI-23PC2ldgBZSl+kQ6T1%7RIz+EvF%ciqzh(8-m^oR( zA1N7mgrsG^9KZETKkSQclFqKJu8S$|At*N&7Ia zf5>*QuhkA6p1Akv24>W8pc?2g7&p1JA8V%=R;%`>RV$M&-DY6r1u9^@(2B%m=j;-6 zkxJOg3x>_Wp7oVfoqNkaXwsSXT^?nY)zm(zSqLQriL_ZK1gf-2;FPk$M@dq5v_DDy z(X#S(6!ZURnMf~ z6bOchWS81J=>44$TG6w}Eo98J&F;iz>f?gqKzBE}eJg52w{z1tbah7&fpgzbqH{=T z$=%e#NVJ*QuJty-^@m)n^_r`>`1>(yu&_r_ClewqLxa4QOp^0& z-0ZJt#ZPXZJ!d%y7ECCa!^p3l(+u8M6(I4v@>Sx<`}*hcH`%|^@nanWiW$}+=VUji z{9cCPchM)EH$aW zlqU!DjwVM7ZGYfBb?GY)uu(Srgq#%E77ox*A>;%AX-54wYbQxUPW$ME zMp`%$IvPJ(N%p-LxwWn)J&Z7#VQ|JKk{s{zzTiYJ@?};<ratXK3f=Bbmjj#?(_e&v-} zX3{>}`E)PZ7{>$6es1LtPW}3lq0(mq;q(Gdh3e}Kwh|RD z9|V=pW1C*t1?oN9TAV)!ogtIOs(rXG)MtK$sFYXa#I#^+fG_ZxhWy&I@lA>i^uVg? 
z8nU|3REC7mA!d2U8+q_QeQWEcli!BtfU5#-GQv$lW(vw1uStJxg+SI`{0K!@e=J z5FI(!%26Q|o3Ld)G)2HH&~>a9Pkpl|O$D*ZQGAc~0s@l()ac;AIyo>j(c$a3PuB@p z@=t-<_7{@!l@oZTNJhtaaPon7xto=MvF<*;_<62`<`V@ec+g(ao10mWER+z%KGr}f zr61n(Mw=-8JM&iTSc|lylHZ3jmZh4s_;eYuH5RnoVvh!JgOfnb#|;(Eeg!ozYig(O zC_|l;;AL5UeVhdVA)MmWNE6CNl^@GkliZ|BV(dNE5i!)VFWoi*tsAW5koR4ns0Bp`6y6K{)Ok6)kwKI)B$ZOvFC82)fwe5Ha z5GmY_IW<^4t2*cOJmD0K*iO_sATF z*b_$nk6oM)pqbiSZw2(F+Dv-CCGpk?p^2c->?^pO4N_xc1Ao7^X9z1_so_BlubaMZiJ3|VgnC#ctf&h@d~Z3 z`P6|d;oC11qV`Ys3+=+L?u^Vr1)8ciT52N{_sWZuta@Goo=V?^K)WvGG8buPA);m7 zP0{zjkOLmR)xlNI$jm zcnFiGUq$@7#2{<3{JJ;+`D5@C>LzHxmhSJ9r|)x zX`r_vOo7v*q`L`6^JwtH*i0u@h5~Z6<|nI}wanv2cvdhz(=P1A&d36Ex(6g2Jtod9 z$CxBCdaTU9&8J3X)ZSVeEw(CGep2l&h8ezWoMNIa&@{RQj-bD_@UmDl4mZ;l^6YOx@*hEj6=4l_KST z@6$6z>Wx`*ZTLssR~`4jYeHPy+JZNpTKL54BDI8b3{X}eFT)(ofWl|Dc;B=lq>o6LnKEC_gK$rbjrg@$RshPAH34r zi%;37j)Zw?1y=!drlqtX?eE6-nQ>LE#yRZN*4?{kugF(JFnbsfEz>G@&U&X4u7CfrY3an(uS3O@zMrFmL}s&#=58SVyx& zfBRfh)gk$J)=1O3Ki@&?;fd`l35SF+4c-z*FtBd|3#vpuZ@PRKoD8W#J9qoTM(n(N z%$I){9Xb(5F|TMtY2DIxM@9zLA+c2Zb}z!`5bZGLlN0cb&Vb9@+oXK^9IT~O^+>{F zeiZRbi&|qM9qUXmYAqSLuD9rpxcwWv9QZ*FXt+Qi=F5^4CQK7J&51ne}o0m zan2n+=SJ#@Q%fX8rW{8VJtX8zO_oT0X>_F0@Xvco@$~Iz;^Op_S!d&&wIcjAD=a8< zE8o*)Xd{2$M&rd!%z_qc{tcg@(W0o!;j<|#?_UWIc(UJprg6l6Zx#yEQ!XX{N?-G> zSJ#qgG%TMt%e+llN)MhKhs({by_%&mG!T(!kK_lbLfwRYfTUgX41B)jCd4r2kAz;8}_q&6ugDP1mS2cZXpUMfI2Q_`dJHZclF^fIw9}Y@V6y zH~2$Gz*X(3D4OYr(lCw~P?>+c=m_7{bDEyM`Gbx7uxL6+&bPvQbn8;;ObO}&f848N z<)y4rsK?(HMJJS8PyEX0u=Vg}>FiWX_L>BlreiPSsWfMm|3=qJk*NN=jSLQ^*MJdF z>rNLRHq=4X0iUlt1vnj;))X!Sxj16KK=IF{*FXaT4S8{J9>R-&aGz{wLU~;p)UWay){$H zBcBE;6GAqfPNG1Yi#3>6VQ_+91~wpV#Dpo%M^g5^*`Vn4`n^R7@!l}7NxbeKGnz8O8E8?v^>{i` z7M3BKPdy{Jj{_&>+F}ywqW>lbeE0}HuWvp#MUGVvE#eS<*w?#-Df@h%SQ~TtZO(jo zy2W%!Ro7~Eh^zrCde;-7WrQA%yDRY$4;$sOAjqw~wvlG(3EIj>aZ`;YjpMNt_tS?x z;;Rb{aY}$yc*|ImppmkhuO7<{KIZC+QkWNaJoUNX8o3^=daP@#Jm;_%q6vs`;?wTlPjY$5&8puE#|=kF=6hb?qq)D!BMv2d(*CW2 z0CXOBo%O0?;^cGumv=yIi zIw|;!$T0ZPbazB(uc{4q6)+^J`+Q`G_rSrMbaP#HeJa70L^_i*(fyxgv!BZEd%mL* z$!$|K%z0aq?(658jSzj*YeCUsdz42z4IKsoLm&2xC;t4&p;Kbfc8gj+>2kC=B;*t| zbKdSq6dt;)^?cR3=%ab57n#tFXV?Pl`=E$e+CEIN>ci7WB=IEGm{JCz8i4{1l_f-*6C0$RT%c=m;SC zR6bkLjYNV+^0-Dtq2u-Kdk}z$-EX!V1E0zkwz^`e-fABK-c3QoN^@%Gr8VQi&Xr?M zkE^wR9GVgav>@VbT4d~2ptRF9krl|3+x=>&#J6sg<7YSGX=kYCEsk5eUhH-iy_HMS zQAE}sfgPL1W!e7zP9qf)PFjzMbTax=f}Ni3i4@+hu%TuzS-q{!iRw4`g45M10GZ|7AwRR~usYcP;*E%zUE zHz|Dztle%N5!L@X5=iqy+zqGDkSVjhKi>7dsq>1;Q05xo`5wW?G=$uC4U6HlXEd~`ENsXS+(STz$(!1?RpgK5G*PMk&#L2$Zp2Vd{Uhoi^2jeX#xR5U zTQ|@W;q)l%2JSsUR4fHW`MsO>;+4$;S|vUpkA{o$d%#&F1td#LL~A(+H%RGSUn27f zZmOIHvxg_*`*O+Afm%OToHzF^=Y^g}M1tIJ`x`7^_txXnw=iM<&+UEN17`=Eq6AX!DcvV0JekF#jTpfudzb??#m3& ztvWh?;kbXuX;dSKGkiZ77%9tdigTIhAl(C3UNvi;UTWCL*Nb0Rr$2hfs#u)0c_pM$ z9_e+PKdjZYqI%Ov1jF<-OI(doA>v^9UHyOx$FDu0?={L=qHKUP6k5HuDh}|EXRpz> z0<;|H$g5kQX=d-@f@~+X-7JnjpGca{eOtP1vsH-jyq@OTKE=!aq=*V9xA&C>y_?oM z*F;TZnt`j5aDq9^H0f=YPS2@tDdFM^owX4AB(pFlpx)JUWvL`x+tubtWi92fpJG`D zS*t7S0Chr9NMR1-`%$4U4m~$LRLk#MJkHL{X&0~2K-&{H1`>^O5vcB#QcmOXR3Wf3 zx_iy-I3btwX1TZd<~~b#PXVU(tK6>Hq8H2<1Dei=VnEZMhn1&(Y}jK)ey*P!?_ZKE zxh}Zhpb|E^1FXWK9UIdiGA4u)S!aM8VO^)@A*4X5V(Cf7dhUy<1Yu(Nd0*I8Ic=E@kf*ow@p^0AwU#z4%5x(7{m65B` zD;q&Jc1N5YINRlFv9u;ZHGbRy>=En$OUB5J1lOn-x5bv$NFjtuboQNNt_j0Jl9di+ zq^8tCKdrm_8sE^64NySw8@u&SH2ye_~OKM0d$cx8y6g)XQd8fw3er$ErcuyKT~lP|YIX{gy%hKOm?n*i zz}!2pC$9DgV+NZzWqTq7M?W%?s(Lkg!Y=P^H{Rmd>MBl>(@%ze3lz!-q7f0D?t3cx zUHwF{o^~T{?2y&opNz-zgd#6uZUjp15=ng8>%YbU1?xEV4 zZX&|+=|P?-F9QwC5;B#HGcwHdhF{jd5;DBWy@U&`zhc=DP(LyK5EZa~4+8b_LU6(u)9pZn;^liF6N zMr9Pmtg(lw#wrvw>S5u1UE`Z(x=@Jgdy3K2ocI!W+zNT_i$F1cfcVof&+1H4K7F?6 
zczD-7<=u0gW%G{KmX0OICjU>DRcDlB#%?s5*@pZ~r(P)(=Z&Lw_L5|KVJc+r{ zf_%kjM-h8V@axJjhGNNIspHr!j>ufTMG;>>7JPH-=3G!t?}a?pBF-F5uQE!eQ~(txI@ z&NQnJ!j2fP7fMk3DsgFr2W=43m&5k0gRDCu%~EvGBk_%5k?7mMa^Uf$fyKn?fAPEP zysZ>+^jqmR_kYAg=>|hWW~gT4$zyDJks2j+P<1z3Uys7{=4AR}H$C76Tmdf-A%mK2 z(-E?-f4A5KceOB>ec~47oTyKRu+ak>T`%>QSO0P%%BUSZ{BXOs@6*><4o^lGNyO`U zl{DvzTAF(-zHcHA4V_?lD4KSv5VhX2{Z6qSM*l<)<^OA}(`+C&PMVd$RjQp`vI}tA z+TM=D*yc6MB!@eaW?1> z&yNC6_0W*qNyp#b14qW7M@qA9w0cx$?LN`stWTgD;$uSR`}fs&)2JM;Myg5A@Oo0{% z$`@(_9dDduxr4syMe(}S$A=SgXa9rgBt6;T=Z~o{%?6Z3_eCDSSztCyNJA|+fU_DHPN10Qr~#G=oBPO^krB9!AF829sjJU z-}&8%OVgC5mUu~huh2=vlcgZ&Ai4g#sryc(x^0&F!!*E#PsvnTZbVtqY99hYeVv>h@2GPi=yYX zIOQVkad19Y!`F>FgcYA66%SrM)0$+T)cA-ppy<&@x^|DBOkx;b@TF!J>sZ9W<0m_x zZxJZszbt3S7?n~<{e?U8v+b*sJ#Zm75x%G63bc*M@;U3g2ET`Y;3F#DTfpf&N@(J} z0k@3E4?i#>d~#jCmKL81lUBmXFGoi_`yaph`L#}uMBb}LruQ3drN+dz&M|WH59M~X z7OgM5?&sGBKg68Zd}VHphMLw=tKRL%bj|KYEZ$EJH})}lg!bn0sJvhjayHldbp1(* zUzO4WRLpZ#)Id3+m}KA$YP1W=siAEqrZVirLIk_@|LPGC<0uuAN)y@Z6hEZl+eAP_ z9qQ!Hz7Nywh1gi~+8z5CuEnz(Ql@)lov{2v6D2Y2o%JTNPqEp%7kF~gvMZF2{XBSn z`pt6hbRv(+6PLDjHf#M3vSS34Z$I%Mw0P|7`gElhaQS-ipNUc9qnfag*S?h5{}c0Tr?=S}iZ?0p6}wKh_9pnA%#>RNy@fp9_p;%U?woCP zP^oy^Sz)lcjGRCbqLl2J2eR4StN#kFk$?${DEu(TPcmWaB2Xyj*!;GCXL?yQEJwNYx_qOdrmg@ zaMS^QsMPy_vHmz&;7d1P+xRMZUuShCWAFH(r#?#H~poRE#Xowsh@32z9HTd+R`6 zsWwd<$6lZ%Vx7+cV@9lxxFQ08U|qiPPa5>@sMfG-vO>Kc`?BP7qEFE{Vj+A-cGDbt zgx%Z?w)xx;A2chy`eSR;coc6xUB4V!~zBdO(CzCgC% ze3F6?mr1vx*{%fI;@}AG@;qnlsU6oUFCHWhql=2DhhHW|mIS5H**sam2_s9qiY zpZ2aZs;RVF2e5;Jqk@Q1L=>e2Q9-2+2WFVU% z=-q}dm6rGU$%G;5Zmee%R=YlF^SzPeU%BRBZr0aHxWX%L>NsmpX9~Rk04XBtPFlZ< zU=?ea@cYYD=s24=(N9;98I>Vcp+`Lls9yt&ka?>EQs#lPU5=yqm1 z<7#gksuNY_l8cUCy1S0JPu3B|d@2UT9$IJ~dG*mF`fbMNTAZ5xVVQEt@U9(aTt0|J z10QtDz4Xv6#MR20Y(1d?Tx{H{DPNj;T<ajJ~00ZKIY!33Bw<5La@oBJe4X<+Lw4Mx<39sBE{W2{-XC0q89d)%u98}^Tf zBbfexadN$UV6Y1X+dFpcufJ0n^TRHO7kJ6~V*_|JCEb|6ik;@i6D#I#0L z|6&AKDy*l-BBh0KXSSE>1O4tezgtDMU+0Hjae!X`?KuAdKKT#u$^RMPlm7)Iw?}IO z#Y3jT*}>Gc{b~6zniVD4v^D&Ok3ua?k8C9-LJdOST_eSn_{oPW@3PUI?U|w7nOH-H zwQpHmbGL3kxCt39a$i1XVWwgmhv3As8kir(0KiFR@rPUccBc}1Sz7rF$1M@A)D5FB znTe6W@oolc>tcXM|2x~$A0p<4=39sAtc*d_;W;f4w@9uY!TJ&B`R?XT_;Po(XbhWW z^Qa;Km}3Yv9>u5J=skAfS(%7*MTzN+YNrSxS-1H%uyDHuVLO###A{>MkNegT(uZJ0 zU#0N254CNlZgJSsm(LXpt5Wn_c_SYlZAxu0*o>-o{>7SQx?9ifVK-*WhCx$P4sFCo z+D=*15jx_t1P)<~(A9LZnSQ)P)6} z<$nKsX>EJVUYqQNQe-^RRhxy_P;QrpQZqwPGq#_ys7fC;;|=DGV`74jEXl7o6`*WA zR-BImfoMPe@2_V=Krh2ACx6Sm$JSTR#eglH5GZSG3@V&!5r8f5T5SotYkbPaNr@+Y z>B!}SlHtDWJp2Ym6395a6{W0SI8VS|__)ax7EvsOp1TY)h5X>K=aGFWKccmombLUW-RBLj&}n}D++{`Gei7`bq;dbbMnR(&I8B*p{3W+5?w#JEQks{c1v`z{^j@9 z0IDcyJ6g~=GqO_}h`nOGdA@v3cMX}u@Lf`-sP*rT$c-2VJ~56)^HVUByDzUmaaMo9 zjnpf_u-5C;A5e0`8bewKRxOGM#Zk~WBEPDtr9NA^Lz!vWa@g#K1NG=-j$*7a4(4z; z8l$kDm~Sob31unYd-NjN5S`A(k2}!bz$@AY#yuQJ7Ul3-d>D0J4OtSg3bx`Xp45=p zTfh9;qxdxtB$#R4Xqz7ccKA$|!MtN`^CPnoG2-EE`s$VP2-x96WH=p#V^aHmW*rim~tShH*9lS$1%e>Qx{VUng zkRvtpjCy>X)1!)=oON@q^6JxQ$#pr?J{p$Rj;`k5BR|aSk^QLQFB1EVlgMZpR2fjC zxW4i(e$w26NhvQ!HPdr;l}^l2ZQ0zbglswY_baMT3>fg%=5G%<%c+V#8DeQdK9QS1JHNfBF|HH5pYb|Ced2n#U^;USxI(2Me0paj%Y9%Bu$wi7>Bf{JRaFRTMIJ?QB|Cr-tPI1@$x50`6V*S@%TkLlZWhY*+=Z2t*bN zoqP;7jX$dbT@tXrCaibTz9z;_kxJ~%n@Mr#vIOEAH;zLkI8E*uIz38A=u?vA%jnl5 z#Y0!l_;7_-lGD8H3{wV};I*Pe0KNL@%(MKe?+&?;j8hbydDmD%NN73Rx=&#U8)h!gD<=P; zZrSAQFTzK!{_OTW+14k}r(lzvG*Zt4ol$pS>_&pql8F_^aQ@6@N0sCU z?r&}Y0kjht;#s7d@6eQ{9&~ckn3F?we$iwv(MJh<%F#6OB5%|E{h97M9;5{AOY>b~ z7TD9f(oT)y2tjnIB!I)uu`YpLaRR@5*M|CV2;Sl~a6GS1(%kb>_1nYeGsH;8C?V)z zxIMuFG?HrWPn~WQjtZ65YQ4im*^eTp=8Z+@vfZGV$l`Kpzd`tq>>R#LxZJHjhc>p1?rqEA-nRC^wsaQRRtT1Y%w 
z4NMq3HHqo!Kb~7Y>}wdKvTjl0opSMvhDtVd8WG5=$+})!f#bpDx7~$sq(%orbJC_d zy0hw<=tecpuvi#wQBZSvpAR|IuZJDI^6UD;Yt=)Y^Ce&IMUV3`!QCmROV%EC7oe8= zAzaA7cD?Z0qlf!2=e0K@`^K8W6Hlt3@3$wnwsqvTU z$B5nd?btBJ>ehiA$w#vvA=jwGKCQPDg%L(C!p{i(q__#7cy_JNt!i_Ofv?!I>AAwR zg!Zht$Fap5D6ql&SY{{rWXJM5xJ0$W`D|)Yc>;>&JzKP-qEivhbXqd`455!QX!Jpt>f+aZH?M}$(%cp!J683O&o$(VT zn1jldTKo#S$@1-&Dc3Gg!Z=M+3nRl9wcp2zj6XQ~QN|qSI75A~xqaWEop|iw{^pcv z3ag^GE#it7Ugoo*Yhy+HrDE%KhD#sbYDP8XB)|ML?5=*FT;4nvLwyHkN zFo=sg&B=fpY?#6p(_sVp21g>GkpU(GZZIvVB3{`|z$l}zk z?}dAO zz;DZPEKVxEdry7LR`Fm-utB2`PTB)|&G-?`ZBdXN{iOASfbY&+qw5*8T50LVx}7f8 zo2QPrAUXnSqU}dgu1W)zeYEoRe6g zltqs5zxPwGXTfq#oNgeadvx{<*uu^d`a4Coer*1QqN3iP;$w2kiKZO;FOxK%ZomuFD1GOQ2+LDHrWthQhp6lKlMfs=%cf*4(ek zp^3Ce7*(_$NPK?VPQ351Y4MdxhM#0qxr+~cOXNKUB(%=4e&P zjP`BlaCGEk(*%k9vFbn|PWD~uoNNe|J082dw_i?KTX{|a4B(ex168I;_oIfvYhd0B zlGSPiqWH8pJNooDRz_qS*X9NaOl z{&EMPHO{jEN@ra`@b#$mkT;>X>5&3mGQl0}TdMcW$Ok7e2Y2@Ab`++Tf$Vv5^uJCE zb|kRk{e~#Ex4M9kzZ2=UD@>k1f?rTB^2j}6wX*aw6PmzEpQpmpZiqI)6Ov)b+wiM@ z;|-J=pW&8HE*LW#$Z`n5H*_TRV22g=NZSu6#3DW6n|S08^;QqYngiAGY<;QKeaCGJn7c^*ZY_Qy&Z+6bZWk&;d9C2x0f>JyuG4W|y zf5zI6HGzA}l@Ivpc+N^Ej^69h&CuJzFK@S?RR#XNLov-HwN`bS-w21V9bqzTIAijh zZ7jrSu82vL=bjY_N8%z2+7;%7RyK8%1wlql1`Yc(2QG$_b9FgV5r5O6&9H0L2J{#+ z(n>v(uI_GK{dcGVoed7Nmbg6*hcl{;FcJP@vv#JKwxcJ$7pJ2%thQ`dfCbEW*(fFh zC$B%2o`z;e_k&U6-5&rq|HVIMy3R3-;O;uQ3J&5JD@!g${l}<3-YQ zTW}k?=b=I4fVoHnzON!4(t2L$3^0+v$}~)L%sMn1M!RrGBTC>vySJQ=Ea-9VHN(8j zI2!ZoYs{Cg@XyxpGXpzOjnB7=&k0pEppm1dH`f=^t3;aN37Ky%l6UxP{f$rgzI#t( zR36Y+(-fz=#3Mg1OOyd+m&9iRWdoC#mXAqvW%lHy{w%!~5p9+b^J9pZwyF~sE)yyQ zaT5r=V?5ns zNBqqdW^vZf0vI=I8mC|O_``HKB>8O8jbNb$Gm(cYKscYJIP_b=X;-edw56NWkM~uU zW4~c!+f&-(m)2#o{Vg8)S^Q*?GPdu_AdfbfQU}6eKG+Wiqg_8FR-0USsdbg46&(t~>C~EeHib3?(c;R+v8$O0KfEoFwyBoU*wOsdYR6DN}$a*g*v1aBMXRYlix6t`#Rb+3^e4@intx+K<7G-;6p+&cQ01A z?O;cLs0osfTYTp#J`v1wG|4mI&qNm=_s$qCBzAa@O@!1eSJxoy-=|Jbo{BFR8+Ad} zSRQg}L?d&ekhdM346_JAXNgtxt4GYRq7MuZfK`L&4)XA92&r(Hxk$||jJi96(JBZ^ z<-CmTATohfQceZtan)7PC}2t~l=~baYwDQ>vNmzaN#{rUXM-2YtAZ>9jU>_KS5Z+M zTZ8A@CVYyTfO`kdyHc8a4M3g+B2$T1v#Sm^z(3DbRS1-YSQLpO4()SEOcUPRr8O5~ zJncNr?{_u&@bwTi)#gPl(wH)CNNCiH_B3NMN)_qSz;ZXrfm~6N*^YwjnF2C+CH?Lj z)n{^11?v|goVmAC)8ucx^K#(LiQ*hB2_bYxWK@Gz=CcpII1;G02#w)|+O^YplRSO@ zVm;#-&l^_vi_P|@m6LZ;?v4#&hZ^C5$>0EKxGsr5cunuO&4+ah+2H)Q&Nhw&D^SyuV;L&c&+UVUY)3r~d})L)30Ho!^3+VL1rgq;D>c+5 zG%rHY8Tl5zecWqXej41pS#{My{3Ryc)^9ZU=LzHeX@Y6V+lx<1JVu>iCT~BOEOQ-A zHC2DIGkpmmd% z7Le)BYO@B3^xn4|yrqj{cSFiy>^rrd!iF53lpE)Pwae6lz`bY=?lgV;)9lvLP~Rh| z18Z-sBp|)>O z&eIylNUlDDta6K@!Ed@Px^Wj@jKPccnm*%7JotJJ|~kqtbI~AKP!|t5Ng$#=pa6 z9R}8IZCa)i&D)-pF#{}~ICxiwthvcX<5aTOV_rezL8z~_p+Ssvcy5au1Ipsz?h94L znHWTmdDDbd()@a^T@BWU`#czwM^E;*XywX$w%&@gFQtyFMLX8JY}w| z&*pV%ukuuY6}S97)ywQ!(im7TsI&9&$A2W5>7<>&LyH|TAt3_)NDBjkz{@wJYV2@B z|I7{(KVy`Fu^9^eL=;j5C~E$O&IU(AfZ<%2^+y$ePk$%K`H5jB{JV_s|4BwYjo}M} z(cRcjdH<392Gl#$>q2v8@8tW(TL3zr{}AN-V|D+JEI{n_A%MVU?`3v2hwbmlz9R5r zKdzGr<**hwPJf(#r!bnjy`>W_5hC?madG{be{e7Gms$6u1mLm9=dRv0u96;6D)7s{oUR&P<>LAn$Gv5@ z5BT4M+Xf~&FLy`e?l5_(Y++$xV`F1$Yinm`_u#<;dwY9FM@MI8XBQV2S65dU4CdkC z;pOG!Er&W|5Ie)8nWvuDr3!oni5d|teG0laeO-MRiTF)^{R zv9Ezq43CeGfAi+e+mBC@l9G~>lTj#CdU|?BP3(scA2KsD(ZHxq&F_mXDk>^z$}IWv zzO=OTc-kDF|zkl!R>+A3DADHOEizj>6w`s3NU_>W@l%q%jCJax%v5d z8jZ#P#_HVS;^NPrKbcJC>gppi@UR8o{RTr-GMlI-v%BGRFQ!ZEEW$z~W98L^Sb z$w6`h%W^#)*f#LpQ5vrAYw~QX?Lq#XpSdm%pWk`>ok2Uc^L%gQ|LZ%_AV5_BeM&(6 zvGZK%fb7|MvK-Jm4=#9_@y_Eh`2QQzeX~f=o<^yLb-tTY)<-;HsDKJOzU7r}ZYfO_G!pQ*EC|aGNv3vpS_@;alSl#hq$KO>zjZ zCt|B{ed&y){M?f7$H`|pt2<1aUDo4x`eCF$Jg3|8wpTSP&>F_Ny=^V8WV&5( zn`doX7~#RfAflnWlj?;$PlAfkAQ&jhZ0_o&R`W9{eM7>Sr{J4{!~x9UYuCF2VZBiY 
z3XfNPvySs<`6VNp9eOlVUCPus#$&}RMuE36=@Q)tOPBB^LV22*PfAG{?t-xEiP_tMkm9f=XrrO=$(xQ0PysKfFu6 zJV3rBPpwhgAi)aK{0z;9&#HjBmY*A-88eL=222`YAHLbOl;BOc+FU$Mccx5?1p*su zxQ$47rm`>($?5D{jbq}Ni(5Au5GG?|6}E4X$*;}abd=@O%u~&8F@No>3)oqzsPC#@ z%}I{7=AJoM8#a+OT)60*FPu;ERV~gUTvBz_eAC3)$Yte=ZjF+|jw#0_ACMvz5<)7zu&?1uHF>_avZg_qe9;wi;Db3bh^1ImbQ zG7Jpf%3{^Lyp}b)LM=m-+|;-4p4#;D9B}pJ^pWP!b*`=h!R;I@zI=C9-HKEh&M{CO z#D?2$hxF^ByT9a=Paovu6b8GYVpVWih)uSW?bi}NERHqIz$g4#YDEQDsxWwLD@T6m z{-zX&htDh%;Ijct)U+H0HaRud<}td9YW0$yXF%}aE%Dm)o`-GB_06}z>uj&UQg0iE zuo-FF$6CvB{dT}TcKtdv52Srz=k1H>^>1lP4%wB<&R;kR#;(ydq$94UK2^Wtl*X7m4$%$g!eo)$Is>3eowOVc; zLI@>3y`u7r>ioep)?KzqTOmH4y(UqZQLt|v*!5$|S`+u5v#Xrm$s5JnZa4nMu5JIR z_s4h~V4>r98Cr&QM1&Ea9T#%wdzZzC4V~)Qz?9Q))ihJ1>TCYJ{a^_f)Ufm+#ekW~ zISQ~V(4)-w{ox%$&FW9MBFeUeCBs9DB)0M;&fH?iZxvy~54>jJXKF!nFsAXWyJZ}? z35M7YOiWLt9Xr=D-wQ#S*F3D9!`UkuyOh5~D3-efe1SJS9dHT!UJ&sUS%@r(4H^A{ z>HcbG+=aYr--jb+DCM)1oQEfo0m9hq*UKMLX9MU4#uhuo~ zKt!P6E>U0*?A_D3Xf(1&`YMkXm8{R4`0*jgq^_vSSEvW0s^Ye6UQn57Cdn?;zUL$X z%;UZ0f9TC56C+?6a0t4ettnEnH*%yi=&NB7()YI_u}V{={V+sUA!y z2v^PqCO)Q|1?~H`$STJ@RL7A6!8oalHjY_*ekLQaBTXJtgB21~`${F`s2a0sL^>LW z;SKCh7TfU|f#Sc}4=)=01YNJE39-;OuL_NnvttUTWCqMr5Sy76sv5lbZ!#eJ*epUQ!nTR^0Abp7 z{Sel?QYc+$4K7gEN7&*pLOC^Z3Zz396tuscyXpDFG7vS^=t08@o533&3l(VLJ+*hZ z4tzFBnv@cr%ZWj^5IKtOjbXHvTol#!tVZQ9b1iQ2Yzz+5)S@V()ss>O_g0or(7@d^>q{kE{Qcs?W$F?q)QTSMWVPMK|0jgqZDi7G5X+avC2#rLmK z$8z37!ydZKt?LB_r$sH6?V;FV&o5|22$}D?g=6!;d{DS}1t&U>G`M3+J3v#{U8O4+iwM`W zO#qRpi?xnzTR ze=ewY?&y_kf!(iUfgS#_yTe!-R&d;0dGs#ajMoY-Ud3_GCk_7AU7gL!TxEa$)gT=p z+!Dq()Nx)|ihb+*yC_UP1Z~Cqh=uxEUB>k32Al z_cmO-k>j36QUUVH9p#yx>$ZCvk7|4z=sr3-*a3G^8;Qny3+rGPWhF*`WEBZ$fm#AmdXeW5&_$jLxjUUeJaq<~u} z&!)%lA<*{=XOoqa;=)<{hUXeU0xRne9K>L07fU`~-Z)P3nN!E!V}1~Me&C#c%BVJHDkp>Q)s@{+I<>yXO#o?9Q$TT7QXvyP&s7S?K=S+_KJG?>;8++>QqWKKD?s z?@07gmN&5FRf^AnyeJFCl*@kBeX;7RfiE7Ii;S&HhxI7}PX%%ozvW==%AZOZ&Y z%+II5ybP=$q@)#sFP-yg_c6*iCcmkZc7{O!8k1`7~eJ z$uArw?+gkOBZ#oBC|q9{IaCL}rK9;iwmCOvelea}yt>{{oE9bx0+e;^5UWgjF+!sA z?R%tHgyH9)p~la-C*6ICNaq`5JAh~B{VS_a+H{GKdU_PtRk0DRrGWS>5=w&lSGDHQ z8rk<>w#d zUbo*g*Z;J&awPVz09}n(>7%0sSJ&4o8{g_t_@7np6|7U|H$rS|lxUN=Uho^if=uN}`{&IdNBn9H&yI)DVig55k z0As*6kZ}nVGot*~AY{f9-)t>ga{E!CW=bkeYRjsl2c`n}zkA;qaXp(d#@PvLFz3|+ zn}?c5$f9s$Qt0bCGnsR7>DNm9+aF+8MKS6&@nvfd0NY6M?=DP3i@hD4%|7PRTX$*` zaN(;~Z^kcIE-Dh5d+z`O<}-yihO+Y!RzL;n%~jxYAstz?tjWZ`Y!q)wq$$CSyIKan zv?LC9xOQ`%TG3(hndR9~NQqY8^Zw0?`!r7+zEa()`@r++y(`~nooqp&8aF=-8h)SG zf;N~z)*cmbRjE6e%5Qwr%r@|IwHpHIs{lm>JxYv?6T^PP?a1-b3j0`2fPw7R>B>cw z9Z7m6%`Jz`88M0Xe%q^wU1a#kbNSCQ=F1p-$}!T6Na{sw;&LyOB|xG`S{fR!HYOu6 zt6TmjASYSD^$eQt7L~U;y|>9k?@C=#Zz6ZlieWbdC-A5#qSyvYfAz`Rif(R43d`^# zB;**V=A6$7AUsVMkXZS|hG;)qtmhPiR1ad^wj?TsU{~58^IASp(r=6EA8vmn)GPjh z6^zic$E-dn!)XfSEakG}ZU=2yn5IJ!ei)ix1KbW#1N)kGnVIK0nV27j%?~%t$bV*% zA#Z|gG-0o@cCw!x$ycKO)U8nGD8;>=Jso<} z;Mt_JnT}7e*MY{7v%E34r!k;sAl^Ih#@p-~ zQV9Fqf3%-x$grG)jrySy!nt9USWb4f-2P(Pe>r>UgO;boB}Yb^^Is*Fm6qBo%qy;) z*V$4K0QhF}^G=mAsWX~A34{8NR%9bZIpN(W+-j+#(hT>-X5E*tN=d-b;#THG z8{EsTAHg1I+{^nIwe_wrQ@!Bk&SrlcH{a$Yxf>A_d-oq z5mVTFp!@xY;A#L_Z(fy+GEI(Qs$hOzV(xIR>;)3`ybNF2tdyc(|=HA?&N+a5d>_5rKa~J=IesHsscG zgON>-`;<^Cg2-*5Y2SoZYE@w?k~=Rh z9L6yS0G30}&)T1Dc&-N_t%aB2S_JF5DaUzaYy>j~zr__GQ+HK2)#DzUuqtb5fDeo?$xywc0Y?iOHMV>eUMuXs6+@fAqF89%=dp8ew$ zP(YXRlm=-&`Y7TR>!pB0ee`_LG2&k5mX7$;)_v@rw|`1>;sT`bud#(JQv;O`v2m^w zO1|7er>W-?hjPXx{06|4j~-1}2G+S=xm?+`mIp=?7cxDKky8GR|1hCCQN2ehqiHUw zWo~7kd#)52Z(gqw^vPr*Glq9QR1j1~Vf++Lx&qm-oVb0{L{qXVo8uN+tFDH~ls zqyq>bsgJ!387TKK^)t#R6ZsjdM2!6P{yTC4sPTdM%L{pM*U7iqgUwcEEs^<-je3A- zmsCcqWG8kdDz0A^E4;sr6y!*Khd@ zRUFe4Xt*DR>$c*5QS#i*b19W}+(SSr5jH!Ma|qC+iZ$P10o5MVCY=V@eshVvMo+@S 
z(OC`zmapDan?{hsl=!D@cvALHs>E#XQnr3Pz2vNN=SqBKpBbQ9xN6HK=LQ!YDFhTi znw5bGczUp<*F^t`SrPO1h1!PsSZG zVTgh3vSvXqwW@a33ZD^)Vlb_weueihnXjl1moIPUC=lPU2VyRf)04bfSk5)Kx9mp| zFDdP*RRICa93!V>#o-S;4PU7G`sx4*u~F0V1G(d{UA82uhho&)0uO!m zMND*@b;~L)39@uCOh2u9`Sw3BuLCWBJTkRstz(}p&v1rG8=n)}kB^WiKRP_8O4=?u zboI>?q(OX_T!`peu#ngGHM@a)**;38vv-aS&PLr9%&am9P2XX?$;wkIIyJG#`=^Vt4nez0>`{Pe zB%)<4xCM_-88p>+@AqSImjgMLt%6HHpeXchDJS4t$zG0i8ot@hir*|+yysbLqsL{q zm3_ZX&v5P%a^ivNj3}9qi;B-pEHbRBnOWqwIMjuj~fW=Ln4jV37R zD`t8%7D^fxl0&39A75HYAd&e(=tx1L#oC^bEG95*^+Bncil}X$9zQZ8-0u4olnh$<=H4tZNfvhZT$*0=iq}XIb=DPjFzhQ*?y1CmZY z0l#>ox#n4NS;1&7erb$fh_pvo6s)BowtjN`+%sdqMk2|Va(~7v&MxHpgj=TFvEv=RV z?;@L&A`EZ!8TWDLrhC|qZWe^tb{wf#zl1Bpl{o@MTcGr@7VEx_w_^ShgJXmP7x zAt_o)!YDIGBFlEfQsQ}cY}Fy7fL%)YFnMC{cTG9N;ps&2&B0(CX2V`1V}h0kmlQGh zry{T5UT#|om_O)?v0m-YFl!4i(#U_BkzUd=6Xv%p3#k{Ej2jWMH8JMK&>95Cjg-@p zY-d0w894dvm8eD-J9%isw#o5#p5e^5v{*RQmMyQ>|Hd$)pe<&~bnSiKi4pp}qfCo? zawb|ZMm4SlT`K#+itpcX%#POXz zn)2MO$_A4|yDs?+K52L%{KJXtsi<=4bLcEWqvtt$vLj- z@kfQ6?_WK-)LJ*cAftvF9e7fMyBP^0NR1pLa`tKdCd^K#47-~7r zLA2&8mHPh)Ili~~1>Xq58I^zs?ht}UuoP7Qo0a~h#S_Zw$*`V^;>q=S1O|CyYb2M=&)~T9ts6b3 zbAPB&1uV8xJ7C3W>z`yx?Ld_FS=^sz7s>?uieAQo0XQBiALNc-zUf}6$2bpr@zs8B zM6ahYH3p!32{v}S+n`#phg4x&-Thih^&nZOs4`b0Ow$#pMB4XU1|tR%-P(Xk)ly0u z`fYxBu7&x3!VS1JI^9fm?r4>|M@15;z-XjG9ih!rzi#w?(@SIH9W~$zGdLRL3LUD; z6@<_$1;>%f8Z}*Q^7n2r)tskvNak#{e^BeUj)Ntbn`u5DRB4VblXSH5f?jA7U&x@X7vOdb6Ibh_pZ)AH zj%4PjUbZ8oZ5B%OEt2ndS;K<}ONWn@I~3G-0i>>d_bT|(iqxoibS#r*UFHpxxz@X| z$zr8$)$LyxGVYcy(}h{p5c!+jOVG#f)a-qm3(HYvZJu3_1e(;9dAjOGaI`u=US+P} zLf|{gKu`SQ4+RXXh<+Cishd!8Qco)IGwUUNYrS<%K4${K`%qdiH4>Ir*kjUnfy~YSZp|aQvHezDUsr-#agfXbIe*EX%NCs4F(<9Fg>K)J+o^wEBeq57cIwM0$@t63CKsp(`lh=GKo6Q(?j<4`d`XUttGFs{wICwO_ksAm1&|n6U~gTr54hDW~KpE7^bu;wX&GubvRQ06Ry03tb>c$934Z_XDEU6KD2?JX zdMWk&oY|=QpP*DeW|0*a>(f_{FKSYS?G^UOHPf=Y{3RQtY`4*+=XW~h4UIz;8!Ve<6TMn4Z2-4CrWJNzqIvY)jTJ<`R8-HDryIf89tI z*0Fwew}7)_eI5%A>I1a1AhCmY1|vbyhZ|0dLr$pp(LdtB!N-5J<}&}Lb^{T?pgoHv zHu9k3(ah2Loa|QHA6E-cEBT5j@>aHWjHPCq3-`{=XX9 zYbB7gUaSPi^|WM#t)AN-G!ls3RRa{DEb{kL!RoHuv49AzYrQ(peUU%mx_n*CK0M`lhkq36D`+1bzUo4kQ?sFC#4VH_fh79bw{MTV2mSub6P+D7H>KHFhL1EzNTMRpp~# zx@pGMpAP}{T`lZg5`s4|DgUU4V7lO2XiiqC7oe%)GegfKy&5;*;EA>}M@<;x2`43H zcHV0s^u=yayCGztF;>q5#gofp5p5MTB}A5_knl?ksUX!{Z$&D6kTGOm2p>J!{>}R= z!)~}gRTV%}hnLl|!3LF6%Sf@^B)prF6s-nJlpA*Hcp2R=gY6YS@%~3apREPeWgcVV zipu+)AL7mlrkJ(GbwPyI*gikc7ghH4R}t*lcQyx_#UcvB0*b@2GRsHtPL6k(YOG8~ z4QKSf5$r@|B!z0NqDD}0%wo{&w%1cRGv^C(>0-!UhO+a&t8b%Jt=}+1ernJO5|m4#D){8|W3`d^tDmYwk)swz3ZdA!QJ z+1kRZaahEhqzGfI7q7(U?g;E%q;VI7|H^O0)Ksw1nOeO`B)Pf*^(T2TXOo;MACY@M zROR+)i=A2{xW+ZVGt{jx^2rJK^xVcLUO+3iBCW3Zl4!|iV+rsV88eex~$%9 zyaVW>*T&V&hG4V;{#(}ygiUpMZlD5fj8$ycImYEwutt&lg4tyl13epsV91035nHev zIICy2Y>Rs*atVFa72EfHJFvDia@b`w#BWqad@B4d$Gh#LM>*#fY21u&i^q zC2Jmh00>~DsmUq@d(wm=mT*Mhol42C@~}pmgVze4bbi&k%rdJaNIik0+wyYN5kT5p-@Kjq*tMtCjx&y~Olgo_ZbB@aN~L zOQ#mjDqIE`=0>H8#U+De7xH}msyZ8Q&&y*vW>u)UAX?QoGxb`gVK0zC9_4!-Ei{yL zk(Z2KD+x+&-L{Jw%k@c%pt$N4s+Lm#h4kr`Ncj6fot2{V`*wp8%3UkLoOP`2*QRr7 zt$;Q@@ULCJ*cK-3PGklj7FE`k9rY@;3?u-B{Gj!@<(#?Rf7G^YQv>q`%Vhsjb&TuZ zn2Q%b+KI{TyR7TG4?sszTmSucTa&K=N=SkW2ed-b2W&6R03@l7q#n8Gq0kzUoP?4J za8m(E+~Z}fmcdWET$o1ajXrTK)v3msg*Z9Wf``rQPydpbB=1A1DlucP%aIXkapgYG z9{CC6DslL%Xew#;TX$Y+O@zo<6dxsQEoo?whUR(>i=0zPn6DYA(mBr(=_Aa_Te(=o z9_mW!d173S%RX`uxdceOQ~O}R=ljpHNl`vqcx=sV z@11{2@OB;%CqJwM5+@^2o$d-AVLgy|Tb>Sc#`IBg4E~SP^M*vwoPrX_&{E9YtK>wb z#1uGPdP#=ETfdVyeJ#B2OB9ULM%_g}M0fmEv$wmnE-LUCw11@Q;*)p{Ev3G3Kv$R4 zJ2nKQr&@|?Q1rd+a*-C;vr>-N>CF_zDk zV7f~MJGbhhg=BQ~j90DNOe-yCSGq*OoY{oMQw3_|Q=7)~>(pTvjhN46kLZD#g5G-F zcetf_&AAztJf2oFKzBk6&;gVb#S#4qr=W=sc=& 
z!?M@{w+gZR(<>lz-r;FUwJg8ZQDqG03qup}8OV!ojXO)rzwp63WKRvylx$?|`wq|Z zaD^tBGivCs`Bh&&ERP#I{mS<1XFr=&ylY(DA&Ru&tzed%RoL#a+;{+UDP-pB*wL6J zIhYSPGR~u{yVKVa>yTdigq*#DwuE&p=Z{tavF0e%ezSvC?~0D9K$n$Xwf7nwCmO#_ zwgl`}_NJGn5InD6+n|t)2FlkALa8HR{!WLOvYxCc5h+{Wd+$#E1QmThKbeGJX`JI(tux43y44R=f<$IoZ9dWfMYi7GAE3|s zu7Z6a7PoXaW?NgZZqjE9=fdiRj2wB;8L!;j>XvSs*4;o|HFd}x^vN%dx_kG$oF6(P zX;rH^oVH?)nRi`z>YuY1NwEKM$7i7VwNu_shfqKsnEyaNn9mI9|IZ_zJje3+6{i8p z2ON(bv$&F9swZ})+kCpNT1Pa_&S>PizGv-73#M24j+@=cxW4(@tnX(aLOQlDhzcDt z*HpvjJ?c&cfBpgJ0%Tr$U_?Jf-6hM-+h17KTUZ*p{XD#jjxk-9at+IoG3|ONL;P8QVsEbXvO`(y)jkShmt|OV<{M3t2VJ>17Mg(?qy+uN zHx;{WT7M|9dKO=xwN8uUpg2wU#ew)$>;= zsT?YSWcrNa4Q{%}Q^OZxmP;RVQ=0w)ORg9LpA7!Ig6p0KvaL{aqehDSm+q{v*{W!| z$mN3kquVK)cpV%0x$u7ehF)1KI+Ga{&gO|ZY|yq}T$xS=PK54WWr!^OtT=PVu{dV0 zZpD2Z!hCO=#cXBfW;oDyNmSwR-J|KMa>(4(unyr%(9pV$PBICh1%oL>HaC1{S%mbQMJDJsk_CH~d;wWj8n$5yU{e~}@v zPSs-Q+tOS`d{tNX#H3b9qm7>Y(3J18N`@JB*>yly6omrA`7rk+VO%k;kYcg!stmIb zQF8~vG0a)}y{R)lp7(?D-><(pElnn%Ele8uJJ@zms%AN!x4|?``fiTCw1`m2&-i zR#(vYZXaA&!}x!)y<@J-9OU2q$0Ae9w7dH}+=9sFtEEu;{_$rvNwJ>$Sx*kfjImZq zS()Mu_fv7*L|@B>tn1L#HYc=XiXN(J`vpY4Wz6uxrc0$>akd{fQ_6ntyf@IQ3%M{0 zFZ@F}dy|#@siN19VoPIvr#{vIgvC}lh#dUbN+m2XBqsj`&hjx9)i*p_2G6>EdgIo3 zS=zM%89HpfJqvTW2d-?04sG_yn1S_d7Q8nmRd=&ZVGF72auQG_7ZLxnkt@xq@5zX;QTz! zEIsa869*ovr-5K;4*Q?#3fXDbsrg|`?0fWQZe!i_uBMNNkA+`n_~hXmT`WgBv(bdb z_D?-SpK^;8cbgSXBzK5fwXtPHdIwj}yITcXo7cxF+}Q$X9Yov~j&MC9o;?IhBduqe z2l}702r@3&Xa#%6RM7$3E5DSVPPDo+$1@_4SEox@nlHyq>RSbmE#z9qE$NA63VP6Y zJBaH`(F#p`jND!G`|v)*M&)-CMhR3K`lYv)^cj~+v-iv`N!SviMR-0Fi;1N)e7FyE z&p<2Ti7ESs=0XCnvO%Qz?njnQ)Ry2MjKPKv(|$Xa#)*hI@+i$pD4!wDW{eV~Ltpub~h-s7|gz&Mw*G zA3nQ*t77BOm%1F#Ya450X)dnK%%n#HS6wIXlPVV*fxbYexl4FT=Rii-*nIr&-l%vR z_yQ6G-Wc)x=3BRC>A>P;lT-8F^+?uI0IDTqLKgf%+Jmu$XIJ zc=|e_}<2bG;?p+ z$Tm_&qWt=(%AXzxvrTAg)5Aodsls}?Y9I3FhIuYV@6mD{Vq2@EiUmP#&-XzbqVZL6l<4U*qc1`}va-10yu9oSF|)f8Ssrij(4`kPQ;Kz>bL zSi_$CTT6YD6S6K#>h2W;_bP^raySS*lLg7p=<@fbMPf(~3#s zYM+j|6yI*uh594@Fj`^zfG)v@J4EItp>RI1CVx)Yv<FXAb;@EKpn=CA$M9rMW&may8^+Xe|KIgJvyzMkI}WvDKg zJlWJf5F*Twt5oKfgt)js z)rUS)pVE#!hW}POuF}20oy*?8f0enp>Y?UM!0VlyB%I2PqL7VU6y!p)Vl=SK#ED-m zx_c%p#%%Ge7ysrsfw#S@9sviQ^?by^r0Nh79fQHo7fNmc7pul!w`oDOo_~?|0z&!V z!<(YbA3QfX9T4a!o3XLwV|uZtu?I8wb(_#Mv%DPK7Ipb6aG&aDHEDG-KWQDDjvdt^ zTKoa}4Rn9^V9FEJxz)iUH=ZqMAqxVM;9K=l0b6x9DG-$vr}~YNt>&C4s2g#SKY&rk z@fNimxze~UCSlj8zU9K%TF{a=qW^FkGO5WR9KB8t<@kI_WhK;SCcGR~jck3qOy?)H zK`E1TCov>yfDNSEb+sKj3$H?&=-9%|;GeR{yu#&a@r=d@LmN^{^M)Q&H^;=v2^}IQ z%Rd^fC2}!TEVvg7o$q|)yQ=6OXa+jQ)an+2p%9(Ohrqoot(9^kwBox0k#5BCAcMnMLfl9^2&k;#7&clXs{!yE zqpYN~Wi30((Z=V5qJX2T9-1vMpzFoFOlM*BoQ73L1X`=PoIBADXGwgB-jtIivw06V zhqjB)d~CRl`)iW!ZZm*%3^8g$es{LkF*Tb9?rhc0FSt<_24X4p!QU~;OYn=)n@R4> zQ=wJm|61L-dI@yObDe0qCRrQL2g&?2r9CvjlyP~7>41upKgs+&PulRBT7YARR~}3{ z<6GqPWfb;|WGtFYI}0djtZK*hPUhFw*?Pa;-)BdDl~t_%VhqV6PnuX!ufLyVP?g|y zY}GdS9Z_8G5wyALy7!yVz8Hv;VJ3b=0W=oka)SQB=RG-f@j6XTqvq0d?_b12O|I>? 
z?`&(Xm_LT};zA!-Y`mGz3f4S69lLGfU>GXFZLcL|6?!m(e({OsZ*%y&RlE$h+4FvJ z(uC??ibrmNXn&;ws*S0p*tIEKe<}ZomYqTT!uSv~>(TOVXk_$!SIChl{@3Rkzlo#@ z>WNd_AX%mgC@w@0Lnh3e6!3 z?8AJiC=c@iiExmWRKSC7MAU=vd5i68|F!dL)%H`HdC{Dt=f$ZN0j%MkzlKD98?pl} z9mScI@_hzds#7Mn#j&y?r5DB%zhv5UzL~weA=?m(_QK*!$47=mi{?dAh+5zDv6yr!Jsz?w~*sLFRQF==`H{3UBBSLz3q=`tW zrRzU`6N$`BD~}f$)Z*)vO?z%Q4@*WZ)Kx?iw!8J%ud26bD;{H{JyW2gMx->{MkwXl z+)$oY-wy*oqGTf?l*Q<}l{=q0QebB25XW1a1{w+7aFhuqjRM}a?$B9?oqMA>I%GrJg zrH1pwZT9&8fqH{_xIYK8^qv02|0gqG1hX>Bu(2=cl`G2)UTd(na-&cQ99YV(Q>t)& z-QDTWe_y}lQtHS>W@rW{S;}d;Ue>==Dbv4Jae$;cpFHlwnL>FRu@J)FFRWWn(<&6z5_7%BN zFN1aJ@>V_6HX;8n74*xNhG&G7b#N}+AQ!P~!~a!7-eKzZ+uyhB;$uxIh9qnoL;sV2 z{D2aGpJ#+LY>K;oF3K2u7j{H@g{y#G-mN&(cjFpqWeb@ ze68aMg)gy%t#0ztkHx@V3qy2XNG2 zGs%E`1PnO6f5%ymqs7W_kmIHCng)MqS6v?T6Xm-ekpT$cuc92uTE^VPNxeW5{Tbot znP<`bF%_w9zQa=;+6Vq9ag0|;g~ygW1cNl2z8$YTxOKO-buE!8xrCYGZulVMIlUJ8 zcW49AZLeB}W-xciULhd1sHQ`!0NB#pZ>r-Cu0z(#`5!%Na&b*tx-=ho0cpBP4(0-x*!sMBb}RS;OzU!vMU^09vy(1b*g%Q8T&=eb^ekZK?QO_Bn1K8 z;)q2clm+IpqmIN1lRA@ zZAOven+D12Z*rGQOlCVv#K9&h7hGrBJ9)|}Q(p_)tHh!LULg4SHL#Os=_LgDkr_<0 zKGnMzx1n=T|H`9LeqkJyWdSz$>f!K`IJ)we6=3jmu^JunYu#V3I^J&#pR$auGR?s^eW%44QVt#SJIwAM zQ!n{fM4yxz5wu3?isVLI>cxA+WoM3su}pfZ#q8;6B#2Fd(p3#3W4n}|d5QYJk@AFV zd(_#uGFwJJuS0tqnpS-&WN}W8(W~zIW-$t5ppHcDU+Z`zp1g6{GRHBtF>Ld(X|J;s zsSl_cQ%F8uTr#X}^VqE*DpiH&&fuAhMo=-1N3DHo@K>g^>_W|a;yaf_O+5rh`gqbx zuexmVIe<*OE#POiAGY47av^OVTCDfmu0c3J_Ot$EQ-W?VDA0SK&5@cDf3Z&ffCvbS z;lsPc+0qloazW$RwMTLB9+MQAo@!o+KB>@CBn=~1kwQx1LtK$G?VIJfd82r+FO;+M zgE#t^4eykiB<;zs%=SQu2UmMqoxj-}Y{Xn5;WkDAA(Aml{B(YtnOXWA6j;ETetOE^ z5-ArG13Tj-=o{AX!|E|NeLCZ=AWBF^E^KT>bx7Q#Cepd zTvbArdC%u{P4f`FPZb1^FLZcXP_X^XjO2yuHAR)+cGGzC&G>6I_#5(S!ltDeoh9RI z*E}B4PAh*jx>;M=<9~rL=OltBI%NvTl<+UVZK&;cog$0J9~U!)U6ZwR#Y4`_dHI~v zJi*b;5LkHY^tFL>U%h(KVBdlTdq1J!`~L+wq*Y@|LxXU@ZSv>EV`6~B))2L(`8|Mf zhMdtI=3MWcLrOw8&#{MSa4eeH9wOZbip9(a`hy^#%wCY;0R@$d?v*kfJ{y}n`;+wPYD=G?${EX6 zz^Lj_KpgS-2X;N{Bu2cgoMX5vIZD{C>pN{o(X-2aisb1LkfH86@ z&+)n%aaDe9?gF={Urmngjh%#L(X^;X2YJ8%(9|^kHGAiWHEN=tWgngL#W;L3-63nW z-b=FR;hxZPBcT7(b=IEoea$vVsn!tw1#IwpcSEg+!usKf#kmLV>S0xbt&|@BWRD5= z;9Uv{N@bsTnOd^nH26GF`W@1cCKPP}QDq1fwGB^>ML4o>@Ho3FMg1%xK>VNy9<|)XsfX2X zka{voOTI+)LIrfczR|hEm=#objWphP|8@OK2+64vpp+i~asTTlX;m)2$3v2;0@Vwo zdmzgnL#t=_^WN%rNlqV8{QATR-KW8M4zMlXIM*ZOva0-Cb)kLLWQ-+N#0=0VA~CMrK6UjugKh+eQ6&(SK}3 zyMVfc<+8PWAX(8vGf0nR3S2$AI#)DzzgnlSv;Y2vcS-fn?nCytPD4Mh-W$$sK?_^~ ze7@O~0l&l4;OBm6&mW<8POIncl(%wiY=?&10E@68e8dc_TIlk$8HRKuh z_spntvSn0^iY1cHIguEf7VoqKx>FOlIE9(xL3Be zdk|Od{@}pzNu#r|S~s~VVX?mOCZPq`Suw$s9GXwWtq?t_jfX_gj0#&vGMxOdNzl^Sgp#B_{kH#zgF zJpD6VAXZYgYWC54zM>i=*`uX*8OB+zd!G+$JL*+q?X!}Yw~s#*Z%0gB!b67aUspa~ z2S`c_+4T-d<68>V2jyc}fhS&&TP~85xK9InRl)s7I6z+}{K@HPF2;%J0vWX*0)M@p zE5}nLNfqSSNv!4aduc3dqM>W@V52a=*X_yp)?W$6mkf?s3sieRrPz!Lj4>aCq6yQN zJ713ZaY~-Kc8lAehiAU*js}cXYkB3)N~jvTtBhwqHaJMae`Yz7f+qGAb;C!d=FYAD zviDskC@3hs($^6m;ksj->$~;|K`}IpJX|WAqSd1iy6}7@ac-aAQFgobNCl8w=eELS z3;$OL;HIh>p2;*=vh@oLn8T*5NNvvVL3W(pUPM6dl!%`dnKqO*>Q%6L!n!~~?Z<(C z&y{%l_yvu4;8;&H7awQB-@`ooqmBn2ioR3LK5KI>I!a=n;gph5%(DBTwcesxlP1*% zr9w}{>S1nB^rF>+pC#Wl%wcEHWo24}#N2cvzi`KeiX!t%4(CNc&u#kLh77$6M|e%T zq4ztWNB<#7_i#Db480fL2)tmu9*v+p+c_77gGa!j_1Jy_S-f>9M#&ct0Pi>XO6%Nb zGSV0_4t&s2u1fF*t8&=mh?2H|kd7XX0P_{v;l!IGdS3RRK!dsNQ_}K(^)ffm0q$70 z&X$2?pIou(p{o?o&#Up~MXQ%aq5wA&mXK2OB{oP<{j>p&wgqn;SfQ&M|(1j%Uvkb{-*B1 zf8LmKo@+67>>db_#2rR-mbYoxs8}@M)*1rkzE4{&8I)F;B*vFw0yfUEKb$OGapavW zw_!CMVMPs}xfR&ooTsHys}*l(VYZ|0GNo#*R&LWDKEPDT+>J~K%JEBQ#Q+abw#VMz zH=Ah>o|DPRj%d}0kfRIH_$Gt+6!|v$EP6{Maj)IUL1yP^pz<5*ITCygNntiexc{UL zSwttvko1{ujk!UMQU4+Ala*!Hi{qgd<<0fZvq*C-KP}W@Q`&`cp}+?g)npH-{EKJv 
z-a}q8Hjpz-O(-W^37>g84(B)&Ik!-h+e+dbIxLuUIwCTy|K^CJ zb{EWf$I7sBEtm`sj@7z2ER$#oO$h(TIXGU0BLHw-*rVS1h=S)nXqbnpJ(|NdJzqYr zYBmyZSFhQ0Bk0S5MV?#Av-L|IcvwvR*FQnti&N@0NX-ahO+dkv-%LK=1aTn54_+3) z3FYC{(^DHO?de|bWy>(9e=kp$x?Jfjp>*$|=CC=j=f}2ms=duy2f<;o*}`gMmnUMx z^}`77f%NviS=VT139%EFSLMT+RoM%}lfP7uJt}&OPiVS$_hwD^C={%g;4EQY+PP`0 z@Rwwdr&lI}(DE+PRO@@qfdVgF={hAX{+H^{#2^Oz zi?34Z26RQ`%RKv9)9k7?HL7pci*Es1gDP}7yIgL@4)=t!-1Pu8tQI?*cdmJ|!k;VZ zK|ULNSi*mJ(TX%0;#%3?>Ba6gG9`2~=}DrCr2E+&nD2l8igy-Ui`n7Ub4n6s&(YeC zS4wa{|B=X)LoceTdSBIk>tkHQiL{{y&Ef7>XFa0z8GQqPH0_&mm|6JYSBjP zm&jc1G4-i`F>Rr~59172NpcUouP5gN4d4LqzorB*p0;{kQ-6_pF-7UYO!a%Yv=0{E zh_njyF}2{{q7i2xls*o$I*rK78>en6Ri|kex#~4QAqgK8%1&yF{(zTNnZ$w&JLLGz zvcG-N8n6fjE&=oU#OrdeG?~PXKx9v+A?Ajk9ekH^F3BZYJ2$P&N{O(4Ux0C%Yyb_$ zfWyp`_%iamNeuns%pZSBUZ12!)!sGfQU+tqq^Hs4=;nlTs zl0lap7rBDjQInbJ)2>Cs+@8@~Zw#c~`L3?@5?+t4NuOtn&GwR{AJH6<2-s4f4-A<> z(e4oMjWo-3e*QNk3<@E2d3EPup>_1hfZ)zm~9x(^$MC9*eDH zHM_70f)-K+0uK8&BLrNY|CKgx6uTGnWqJ;G0}CMj(5*AcIqy{uHN0-qGr!-R;v}n0 zYzkcB94>I+{HTgePGh)t<|@q#r8#k#h;rYS=WL$x7Z? zzFUyGHodEsNJ7rp!5OjnkBTpbru|xvl(uS)7&y;mrbWIRUMS{F&Xyl{j5sL4*B@J} zu=ET60HW{y!&Z8=x%$M21r+TIA$I>95y`0{i}QE}7~F@-KI`-S3cE=GRbSm(L+d5d zH^0;k`_FUQMW@!@un_j^PFO!FsZ*Xqk@P#PJq2j4Ykzr|TV786&~KZg{ulG4 z@p(wuoxW1CP<+o9E+vsbn_cwFm#Bsh+q!s5)&gX-KPS|vW%xJfok;KA6tu)v=lmWt zk7oM4DUoQz@w2P5Vk4azQ{F})H?x9hm9s)WyT2fWE0@g)d^ehRt?!e(`sLXaN&myA zPXkiu-xD|-R^V9sh-fxf7s-Ds)@cnGu`OGt*{LQE6+gkQQCc$W$F?L&M;;4u!i;PA zWtDwZuUa8hktOOk=R<1soU=O*=_}UX%oDoGnLHlO$G4PxxwRuc-gf#Q|$-~~b7U{Q} zQ68P`#{xUE%pGI9@I8v$LrMs)!7OtH*OazYRi4-7d=@ap`Ok|R!KF6|r6zsDeK1F6 zZ8GESJWoi~ycSvqIhtpgS4$=U=zlFxwXD1oA1|QT`VvOO$d6&1 zY)?}LR-IDt#KKI0Bc1<@;g4A&UHRtlDYL^)-!!J=#BD2Mp53fLmjlMOX6Kmslxb4! zw>xL`HLKWr)qA5n6tksC4Ku&STQvF!&Ftp(y#{+jcAA9(zQfD+l@d%p;JHD)bL=&{ z2loctI-^6YH9!31UHbWtstWE#{lljNl*BjPQY3bc)j-;rbca%HMHfK?goU{RtmILOZtR>DQTLm)I6xvy z5ES5G_+kAlK$!k%AI0)*$Aj}Ck=MfB##S)#4hfaCGd&N9{9&OsZ27N0WZ0WJABpkf zJm;lX>!=t8wm1#gMZ{G;QANfZ$e~5y00gEdv16R(UB0(24#02nYwGI zwKh!DNaw!Mc9f&KRLa+Kx~z0?$AsuLD#{*UYXn-mlukit_CTObbxYE=8X{hzq#%5X z@ccI8Td?};U|OO`hPJoE(X%Zk8+-8UU(`G+HWe$yfi&H|KYd&7?huZ2ZfG$7;N;=! 
zKjlt)rCdG|)=(al0k(^x=g-3uG_jK!FM8UizKfv~EXMqY%F025cH>$PpQKfs51t(7 zC#rr+OM7TVj+DCRR_b61{X8<3m8rWmGVR+#n$hYiuauwPAEkZjdr+-4vN}3uyCa?O z0ld{o+>>T6NgHA$urlYxjQbn6NOSphp#%qWYH?~ip>(lb*9|cJmE`N?WDzh*<9Cibss><>09_E9 zTKdOm>6OsJ={IkPcgTiDDFbEKwV%m?K)Ci0Byz+fdZ9^1SF0U0%jW5xxEeV}z3%Xs zXSQ6$e*vU?D3twVEuP&~rJs#z37waaBjl7m5d7qEa9*es6o;7h)vSAt!TD)LP=8S@ zYqZAXjqe<2wcH26KbpBa1~XGJ-9^8mrxcLLjRUOCe{G?u+jvN?!i{;t`H0HnIpaSG zf^S+X%MBDG`ERAC@H@`#m)Zw;=669aGEXuo_}BLF>VW~Tv-TFXi^08BXQ^VkSCJd8 zhG!5p>ow$=-gaGh+NIeGjx%Mr zA^~tE?E4acu%!p@Uw`d+<>dP*jqVun@qwz=x0O7^5$OE>j~8g2JBh%a6c!EMsGs0k zSu7q>H$@C4i<_b$-_eu^?7_D_I&b7|#y1HTg{TiR)hURM7wT2nCnuvnD zYAwEopq&5dm>mCvXzoVRk(005E6hjw2s9^6jHSwPz}T44yQd|(T;^}mDTr^_j4rjv8$mQ4Uu(75wLCr+@v7XcZN`$)S1(Gtl$qnpBK&QtIgJ-3kRC5$f7r zL4}P;0N;cbwo<2gDLF!2>8J{1X>zz_eNtLxepbqB!MUcdw+(nJFwu@zP7&;S{r&t6 zlB;3jTHbq2ADqx0$TU~%K*o9Jx}0Df5BEC3HpR(T<#@q} z^xN|v1~{Nbhi6uY3JGI|Nv3Xn(Gsm{T?zK5WxhK`Je`!}(p-NhDDOx`t(th#){+oT z28NNNB_zkV(f6?2e>S8%0AIzsG!q~k-p4|ur4)%sdyW(;NzcrGJ6X*siuY$Le+A@a z&8FDxvZ5;1Kw#=k;5u$)fBso&!4tG{ng634HicE>?5Ws zGLo6|>d~o|#@C;<0i9bhk@G+nI4?wWR=A_2C%FVS{|Im@)mPA42VCzdyk&i~k_hQ7FgW- zH<|I}&dR7ZloRNtXp(gdKIL*T8RZOQ!rP_%x5V1+qb5%lIllR@(*5{U??iw}4^6GN z2-$6Je{=qLj5?MCVb!b?AP$)HDnT1BZ#cd$kqz-4Dq8tsfUQL=xD)wJNB!TI;wBeA z5cd(*W^)TY-kK+LNkk?pH18foEpkXAc1p62XX)v~9%qUbvMLgtZP3rtK}0A-cA6DMm50k9-ZSKt6-zul3njuoj9Z(v{*S zTrMRmuk5^6pz*F5Kjd0eF=Lo2n3`nuXpMfc15VhJ$XJ=zKYP4m2C+i6jIV4&Xvys} zn{u*JjMqsBZc2BsCvFFAiaU6hPKi)4o{@J_ytKv29an{hl*FH_*c7SGStLatL|WUv`nLp*gb``A?wSX+aONN+#rL{8!|F7^Og!kPEZbZbgkn6dF^8t z|3j51O~2mIFw|U4 zW0`lm+y+u_mZBG>FI%{G^xNkC+Yqir|Eh3r?I1|f2@ zj=VoG=|2Zhn6m(g^x~nB5;?+8$$JxHCAeFU)J6$K2*jo9*^b-i>C=z51{Ppj%XAHz zl7iGFMSjyO5Ne8hBzFDX1HCx9dBYA?&y`=>K`XB76zszPKcd&16BPMz*zNgp-!xB=yZfJHZ2<(HTk@8L#r%f7Z!(;yERHD!A#ntRlkco3dkYU<}Vv{~ttu_S{c2A3aFtw4vp>*t#+ z7rgiMm$ZBjDU1#p(|NN60H`5y%k`B*$+kvArR?`$a_OMpl6QM&gQODONnQQ)x@_l} z*WqjKW^(L_QnpWb-CA+CeQsBJ5pFq238hKY77QN% z#-h{(J4Sc-2q7L?^ryJ1bH-SDKuSwKbZYPonzW_&ch1aHY|9N|@QPy+ zt@EVxxh%0m)$Jv^$6WoORo4sSD}7r!e<^YGai?_q?u+FMkP@t*Y%>E*rs3(_Q>0@2 zCz+Oz#Zi2#r^Z!|ELW|%I;Q;!gZ5nKydytQTM%dz-ZKMcA9Mix*F ziH0?KsYxzn?WnKSwF5DDB1o+ryFARC4|4I!8|kGmhLYT}DhuR;HRp6e@BEIBoLLP? 
zF&nmA!VZG%kFFCIoO?d;sMw6d*`rg!bcgN_c}@nQ-oyEBPPNkm*(U;lGrswA#-mRq z1_zaWa{y&Bsw)CE7A=hmA=R2Xa?C3b3fy9cm)8CCCy^s&DG^e_ADG;^?s7wpW47Zv zlfC{^X6)UXHvUl)2#yjUSHp0KN6@47Vt>-nEY9N*ALhd@%Kh)ZF4TKEc`pdzi8{d4 zh-9ExerHYuFtCZYOR=URbW;Nf#1s|*&wi($JW#FfvUzF0-&}3IEDZ(3`F{_b033Ff zv3KJs_ZuskV~a%6fgxhBgMY-r4|*lT?ab4<{Pb>kL->1O-q^nd&o8}xtql%3tm8CO z;Lonw-TYts!wU-oI%lXNz)?8f!Dlci3ikBl|GAoF2!QoJ`#%Tjeg6agOCIBpA09x_ zY<=CI1Q$@XJ*D~(NRddHpJpKgc-bwftrzAE=*FZV{1MC=RxXz3Hh-W7CZq}d|Nhq> zVO6zg^tP*~Cl{~N!V-zJuuZxu&sj7?Y(eGawJe^v*4){TQqpzd;N63I|Mz0=4ce68 zd%Q)ax2uNa=bXpLOM^|VlbK~x)D(7e_pViB6G#P&O;4_Fr%vB4!G};W<`ykeDI$5} z(10pR?AMW%4ezCs8$-171f3Ktu6@xq1Dw(=T6dgoGC48VGx^%<9PjFA2UwD?vx#f7 zd_v^n-H(5%hjr_z3KwEi-fD<`Zw-0vg zW-E~ZKdyP)5{W-C;|nWxjctW+U>?KTXpx;7*qOP@^4yTIn%o(E7zm1IcwTV`wPF;0 zxhNkrxHH&gr>?u^^JD1Nvra|yhk3{H=4H>42oCDz_8eYRmHwSem3fn-o6xSI-$NwM zY;L#P^)`-A5DgRSu4=H2%$T%VAnW9UB+c6uISWJ~glVXNV+wi;!U-)Z zuED3#h=3#}+NTDSfC|U+46`e-8q_PBuO~l-4m%_u9?u67ERlROJhhCO2xBN~ep1}1p?Dq;wa3X^fS|hQ#VD$KY=t}ix%t=4_n+h(f{Vt6cB9G(RHBgDRRP4{g#WGK2 ztwG23nyx)CsWi*FHqILynQ;4u6!9+SR)aDRvCF1}3iJaq`DUq}`_2AT@yoCVib5^f zL&mFlCCSYg$_eT#i5T?iWHE z9Z~s!B=%KSxhda7UYDMZoOaWKcT?;F66zgpk%2`Kog;6T`^Gu6E*rIbXaw~J2|Z6Q z^3@cPV}si~jMd|^(t&P0SrC`2^Yh{C&{*geks}N&V9g|X8!Ue4{$Jz*(}3_cj-*qx zyoZ@$O)tc9rQXWVscLDGf~pTiCb)d=;?GgP?K7%sc`Bmw$)ra=_cfz$dnJ!nuLJ%( z4@wkYVEI9#MSI>vjtl=-}Fw7LfeV9mbMA)Uy?M|2u$ zZG!(f&M%=picU~^o@3z8W*M$)JP970DP?0*TdrF>yr4%gWBI{z_%}0ZLAQo7VvK~onAvZ= zOXg?BN=+>s%3?YH&@Lg_y_Ny#JSAa2NOWo@!;k$3ztVDSC3+5D8UyAHfVp1dw{CL$ zu-ehHJ7d-$*1sOulR)(FNL-uJ=rUwjw~QU}F<^DMMuiZ@a`(z_G4ucr*^I$MBBwwHqAuA?k*B+RocGw4clS_FEqD@{fh5ZcO z(-!am1d{#j!!lAYC;uFIrI>N`6Ry1Wz6E)f+`dgUIG=*qirVlmOJ&69!Ohjs@uPa( zh>$}pclB+P9L}vR7+4`wyA7@V{LkM?@#2OW9_G;mgj)T#@vE{5dPo@M-pS5An%UoL|@*%n9)X(@gtf-1I^fo#X!3`{##q7b$D0vF#4}EyM-*_KA@j) z$rRmfvf71aPZ^rGU#+r0JoD$+5XRS!UfdwPjFlf0nZbdnSQjI0AOa(DwDr^+P8=JM0*m&Z2S!iDoK!=QI(Y zuCjg3(pve#*wkFLpyTeZc2ev&Hi=&zJq72}?4?+0y#5e$t45-5V#zpQo_Rh4BLAm# zLBSul!l-5Urtd=0#_lBhWu70Juy=qUe`phYF6$rBKZ+#^!7JtNbS(eWqRRzP;f0Yh`V3@Z4>3ekkmyF4=CpV=={dWu_rr2b%HD65R&i`Q-4o;&b!3O z6ksx=&zmRFACd>SIC9rUit&1Ie=(a2EcvH@E8B)R`u9P1*n{n75#Y(Hv)zMAGKzR& z@f2~FElH%o?aQ)1zw@)G|5KWDpU3pwA22M|ux8vEUZmko8{o~QlU_jWHhRmqi9A(* zuf-Y}(W%30?)dD->W-p?0k{;ibn5plC4z?Ecbj?VmrIyv=b2v43qb19ZXJ_EqmGNM z=(4e7|4TEVz+`0g@r{j<_(hL3$I2_S&!K2PF`0dM2uzhm-)9=~A7t3*tx-!0su$F~ zdcI(7+^YI$<)JbdOYWRInA-KMwvJe zxNCC&?B;)rFTu0DMBS5XEBSbSy(GYZ!v}r+E$Yp8rQKI@DI>CfP0SG;2UYpre=7-H z4N+*Y#S3H_-pn<4Mb$6Ezer*!GW{dyfW5#ZK4`-7isb$*8>x4Pz=>sv)+JM^b@^DJ zZe+?2KXpr*0Uhj!Nd`ZAmf9vxtdYIMI&_UZR??x7LRq`>3(}&v@o+{iIlNs1Xtv=? 
znwht-c2a#5*!2^XGOEJtmEyicc4Y=_PO~>~bJqtxXdxor*%-a9*{WqvyvtOUEZ5$) zFN!}8fUgo}LI{RN2}xZm13L$x!a0qUj)OC+gS`91no5QB(nb@a>;P_g0gJ|d3r8K` zTMdT^Lf8wo2nrNP=X!Oj3OkchSZVP%Z7R27DvnC_4y83`fmZ2UtKFCE789AJB|}dG z!uOj<#HB)3&#hnGR)C+C&jMeg{t=glKI$-6 W^F`T6=}P~jy|IaV8s0nhkN*YA2sab} literal 0 HcmV?d00001 From fb170439e86220903960269abd4dbcbe31a4ff6f Mon Sep 17 00:00:00 2001 From: eqy Date: Wed, 14 Aug 2024 11:59:59 -0700 Subject: [PATCH 06/53] Update half.h (#1709) --- include/cutlass/half.h | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/include/cutlass/half.h b/include/cutlass/half.h index e5dcd71eb9..a0f398284a 100644 --- a/include/cutlass/half.h +++ b/include/cutlass/half.h @@ -602,30 +602,39 @@ struct numeric_limits { static int const digits = 10; /// Least positive value + CUTLASS_HOST_DEVICE static cutlass::half_t min() { return cutlass::half_t::bitcast(0x0001); } /// Minimum finite value + CUTLASS_HOST_DEVICE static cutlass::half_t lowest() { return cutlass::half_t::bitcast(0xfbff); } /// Maximum finite value + CUTLASS_HOST_DEVICE static cutlass::half_t max() { return cutlass::half_t::bitcast(0x7bff); } /// Returns smallest finite value + CUTLASS_HOST_DEVICE static cutlass::half_t epsilon() { return cutlass::half_t::bitcast(0x1800); } /// Returns maximum rounding error + CUTLASS_HOST_DEVICE static cutlass::half_t round_error() { return cutlass::half_t(0.5f); } /// Returns positive infinity value + CUTLASS_HOST_DEVICE static cutlass::half_t infinity() { return cutlass::half_t::bitcast(0x7c00); } /// Returns quiet NaN value + CUTLASS_HOST_DEVICE static cutlass::half_t quiet_NaN() { return cutlass::half_t::bitcast(0x7fff); } /// Returns signaling NaN value + CUTLASS_HOST_DEVICE static cutlass::half_t signaling_NaN() { return cutlass::half_t::bitcast(0x7fff); } /// Returns smallest positive subnormal value + CUTLASS_HOST_DEVICE static cutlass::half_t denorm_min() { return cutlass::half_t::bitcast(0x0001); } }; } // namespace std From 8d8cfdf37560b6d799685daf80bee18c96114732 Mon Sep 17 00:00:00 2001 From: Haicheng Wu Date: Wed, 14 Aug 2024 21:12:44 -0700 Subject: [PATCH 07/53] update 3.5.1 readme/changelog --- CHANGELOG.md | 7 +- README.md | 3 +- .../gemm_f16n_f16t_f16t_tensor_op_f32_sm80.cu | 8 +- ...16n_f16t_f32t_tensor_op_f32_sparse_sm80.cu | 61 +-------- ...16t_f16n_f32t_tensor_op_f32_sparse_sm80.cu | 76 +---------- ...32n_f32t_f32t_tensor_op_f32_sparse_sm80.cu | 65 +-------- ...32t_f32n_f32t_tensor_op_f32_sparse_sm80.cu | 128 +----------------- ..._s4t_s4n_s32t_tensor_op_s32_sparse_sm80.cu | 84 +----------- ..._s8t_s8n_s32t_tensor_op_s32_sparse_sm80.cu | 84 +----------- test/unit/gemm/device/multistage_testbed.h | 8 +- .../syrk_f32n_f32t_tensor_op_fast_f32_sm80.cu | 2 +- .../syrk_f32t_f32t_tensor_op_fast_f32_sm80.cu | 2 +- .../syrk_tf32n_f32t_tensor_op_f32_sm80.cu | 2 +- .../syrk_tf32t_f32t_tensor_op_f32_sm80.cu | 2 +- 14 files changed, 34 insertions(+), 498 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index a418edc011..c784107be9 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,10 +7,15 @@ - Exposure of raster order and tile swizzle extent in [CUTLASS library profiler](./media/docs/profiler.md#GEMM), and [example 48](./examples/48_hopper_warp_specialized_gemm/48_hopper_warp_specialized_gemm.cu). - [TMA store based and EVT supported epilogues](./include/cutlass/epilogue/collective/sm90_epilogue_array_tma_warpspecialized.hpp) for [Hopper pointer array batched kernels](./test/unit/gemm/device/sm90_gemm_f16_f16_f16_tensor_op_f32_ptr_array.cu). 
-- A new [`GemmSparseUniversal` API for CUTLASS 2.x Ampere kernels](./include/cutlass/gemm/device/gemm_sparse_universal.h) leveraging 2:4 structured sparsity and [support for LLM friendly tile sizes](./test/unit/gemm/device/gemm_f16n_f16t_f32t_tensor_op_f32_sparse_sm80.cu).
+- A new [`GemmSparseUniversal` API for CUTLASS 2.x Ampere kernels](./include/cutlass/gemm/device/gemm_sparse_universal.h) to enable serial and parallel split-k for sparse tensor cores and new tiny tile sizes to better support LLM inference:
+  + [FP16 TN](./test/unit/gemm/device/gemm_f16t_f16n_f32t_tensor_op_f32_sparse_sm80.cu#L269-L393) and [NT](./test/unit/gemm/device/gemm_f16n_f16t_f32t_tensor_op_f32_sparse_sm80.cu#L269-L411).
+  + [int8 TN](./test/unit/gemm/device/gemm_s8t_s8n_s32t_tensor_op_s32_sparse_sm80.cu#L264-L452).
+  + [int4 TN](./test/unit/gemm/device/gemm_s4t_s4n_s32t_tensor_op_s32_sparse_sm80.cu#L264-L452).
+  + [FP32 TN](./test/unit/gemm/device/gemm_f32t_f32n_f32t_tensor_op_f32_sparse_sm80.cu#L427-L642) and [NT](./test/unit/gemm/device/gemm_f32n_f32t_f32t_tensor_op_f32_sparse_sm80.cu#L427-L456).
 - [CUDA host adapter](./include/cutlass/cuda_host_adapter.hpp) extensions to support TMA descriptor construction driver APIs.
 - Inclusion of more [Hopper fprop, dgrad, and wgrad convolution kernels in CUTLASS library and profiler](./python/cutlass_library/generator.py).
 - Support for residual add (beta != 0) in convolution kernels.
+- A new convolution [epilogue](./examples/16_ampere_tensorop_conv2dfprop/ampere_tensorop_conv2dfprop.cu#L269) for CUTLASS 2.x to support non-packed NHWC output.
 - A refactor of [include files throughout CUTLASS core directories](./include/cutlass/gemm/collective/collective_mma_decl.hpp) to reduce circular dependencies and [tests to guard against them](./test/self_contained_includes/CMakeLists.txt).
 - [A guide for setting up VSCode to work well with CUTLASS](./media/docs/ide_setup.md) and [expanded code style guide](./media/docs/programming_guidelines.md).
 - Better support for MSVC as a host compiler.
diff --git a/README.md b/README.md
index 3c5700a027..1426e8a42e 100644
--- a/README.md
+++ b/README.md
@@ -51,10 +51,11 @@ CUTLASS 3.5.1 is an update to CUTLASS adding:
 - Exposure of raster order and tile swizzle extent in [CUTLASS library profiler](./media/docs/profiler.md#GEMM), and [example 48](./examples/48_hopper_warp_specialized_gemm/48_hopper_warp_specialized_gemm.cu).
 - [TMA store based and EVT supported epilogues](./include/cutlass/epilogue/collective/sm90_epilogue_array_tma_warpspecialized.hpp) for [Hopper pointer array batched kernels](./test/unit/gemm/device/sm90_gemm_f16_f16_f16_tensor_op_f32_ptr_array.cu).
+- A new convolution [epilogue](./examples/16_ampere_tensorop_conv2dfprop/ampere_tensorop_conv2dfprop.cu#L269) for CUTLASS 2.x to support non-packed NHWC output. - A refactor of [include files throughout CUTLASS core directories](./include/cutlass/gemm/collective/collective_mma_decl.hpp) to reduce circular dependencies and [tests to guard against them](./test/self_contained_includes/CMakeLists.txt). - [A guide for setting up VSCode to work well with CUTLASS](./media/docs/ide_setup.md) and [expanded code style guide](./media/docs/programming_guidelines.md). - Better support for MSVC as a host compiler. diff --git a/test/unit/gemm/device/gemm_f16n_f16t_f16t_tensor_op_f32_sm80.cu b/test/unit/gemm/device/gemm_f16n_f16t_f16t_tensor_op_f32_sm80.cu index 5186b818af..794ce6fc73 100644 --- a/test/unit/gemm/device/gemm_f16n_f16t_f16t_tensor_op_f32_sm80.cu +++ b/test/unit/gemm/device/gemm_f16n_f16t_f16t_tensor_op_f32_sm80.cu @@ -182,7 +182,7 @@ TEST(SM80_Device_Gemm_f16n_f16t_f16t_tensor_op_f32, 32x256x64_32x64x64) { cutlass::epilogue::thread::LinearCombination< ElementOutput, 128 / cutlass::sizeof_bits::value, ElementAccumulator, ElementAccumulator>, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 5>; + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 4>; EXPECT_TRUE(test::gemm::device::TestAllGemm()); } @@ -220,7 +220,7 @@ TEST(SM80_Device_Gemm_f16n_f16t_f16t_tensor_op_f32, 256x32x64_64x32x64) { cutlass::epilogue::thread::LinearCombination< ElementOutput, 128 / cutlass::sizeof_bits::value, ElementAccumulator, ElementAccumulator>, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 5>; + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 4>; EXPECT_TRUE(test::gemm::device::TestAllGemm()); } @@ -257,7 +257,7 @@ TEST(SM80_Device_Gemm_f16n_f16t_f16t_tensor_op_f32, 16x256x64_16x64x64) { cutlass::epilogue::thread::LinearCombination< ElementOutput, 128 / cutlass::sizeof_bits::value, ElementAccumulator, ElementAccumulator>, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 5>; + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 4>; EXPECT_TRUE(test::gemm::device::TestAllGemm()); } @@ -295,7 +295,7 @@ TEST(SM80_Device_Gemm_f16n_f16t_f16t_tensor_op_f32, 256x16x64_64x16x64) { cutlass::epilogue::thread::LinearCombination< ElementOutput, 64 / cutlass::sizeof_bits::value, ElementAccumulator, ElementAccumulator>, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 5>; + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 4>; EXPECT_TRUE(test::gemm::device::TestAllGemm()); } diff --git a/test/unit/gemm/device/gemm_f16n_f16t_f32t_tensor_op_f32_sparse_sm80.cu b/test/unit/gemm/device/gemm_f16n_f16t_f32t_tensor_op_f32_sparse_sm80.cu index f2ddaa5a94..81aa8016e4 100644 --- a/test/unit/gemm/device/gemm_f16n_f16t_f32t_tensor_op_f32_sparse_sm80.cu +++ b/test/unit/gemm/device/gemm_f16n_f16t_f32t_tensor_op_f32_sparse_sm80.cu @@ -320,24 +320,6 @@ TEST(SM80_Device_Sparse_Gemm_f16n_f16t_f32t_tensor_op_f32, 256x32x64_64x32x64) { EXPECT_TRUE(test::gemm::device::TestAllSparseGemm()); } -TEST(SM80_Device_Sparse_Gemm_f16n_f16t_f32t_tensor_op_f32, 256x32x128_64x32x128) { - using ElementOutput = float; - using ElementAccumulator = float; - - using Gemm = cutlass::gemm::device::GemmSparseUniversal< - cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t, - cutlass::layout::RowMajor, ElementOutput, cutlass::layout::RowMajor, - ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, - 
cutlass::gemm::GemmShape<256, 32, 128>, - cutlass::gemm::GemmShape<64, 32, 128>, cutlass::gemm::GemmShape<16, 8, 32>, - cutlass::epilogue::thread::LinearCombination< - ElementOutput, 128 / cutlass::sizeof_bits::value, - ElementAccumulator, ElementAccumulator>, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 6>; - - EXPECT_TRUE(test::gemm::device::TestAllSparseGemm()); -} - TEST(SM80_Device_Sparse_Gemm_f16n_f16t_f32t_tensor_op_f32, 32x128x128_32x32x128) { using ElementOutput = float; using ElementAccumulator = float; @@ -351,45 +333,7 @@ TEST(SM80_Device_Sparse_Gemm_f16n_f16t_f32t_tensor_op_f32, 32x128x128_32x32x128) cutlass::epilogue::thread::LinearCombination< ElementOutput, 128 / cutlass::sizeof_bits::value, ElementAccumulator, ElementAccumulator>, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 6>; - - EXPECT_TRUE(test::gemm::device::TestAllSparseGemm()); -} - -#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 900) -TEST(SM80_Device_Sparse_Gemm_f16n_f16t_f32t_tensor_op_f32, 32x256x64_32x64x64) { - using ElementOutput = float; - using ElementAccumulator = float; - - using Gemm = cutlass::gemm::device::GemmSparseUniversal< - cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t, - cutlass::layout::RowMajor, ElementOutput, cutlass::layout::RowMajor, - ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, - cutlass::gemm::GemmShape<32, 256, 64>, - cutlass::gemm::GemmShape<32, 64, 64>, cutlass::gemm::GemmShape<16, 8, 32>, - cutlass::epilogue::thread::LinearCombination< - ElementOutput, 128 / cutlass::sizeof_bits::value, - ElementAccumulator, ElementAccumulator>, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 6>; - - EXPECT_TRUE(test::gemm::device::TestAllSparseGemm()); -} -#endif - -TEST(SM80_Device_Sparse_Gemm_f16n_f16t_f32t_tensor_op_f32, 32x256x128_32x64x128) { - using ElementOutput = float; - using ElementAccumulator = float; - - using Gemm = cutlass::gemm::device::GemmSparseUniversal< - cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t, - cutlass::layout::RowMajor, ElementOutput, cutlass::layout::RowMajor, - ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, - cutlass::gemm::GemmShape<32, 256, 128>, - cutlass::gemm::GemmShape<32, 64, 128>, cutlass::gemm::GemmShape<16, 8, 32>, - cutlass::epilogue::thread::LinearCombination< - ElementOutput, 128 / cutlass::sizeof_bits::value, - ElementAccumulator, ElementAccumulator>, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 6>; + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; EXPECT_TRUE(test::gemm::device::TestAllSparseGemm()); } @@ -461,10 +405,11 @@ TEST(SM80_Device_Sparse_Gemm_f16n_f16t_f32t_tensor_op_f32, 256x16x128_64x16x128) cutlass::epilogue::thread::LinearCombination< ElementOutput, 128 / cutlass::sizeof_bits::value, ElementAccumulator, ElementAccumulator>, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 6>; + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 4>; EXPECT_TRUE(test::gemm::device::TestAllSparseGemm()); } + //////////////////////////////////////////////////////////////////////////////// #endif // CUTLASS_ARCH_SPARSE_MMA_SM80_SUPPORTED diff --git a/test/unit/gemm/device/gemm_f16t_f16n_f32t_tensor_op_f32_sparse_sm80.cu b/test/unit/gemm/device/gemm_f16t_f16n_f32t_tensor_op_f32_sparse_sm80.cu index 3b1e85e750..01e191ba64 100644 --- a/test/unit/gemm/device/gemm_f16t_f16n_f32t_tensor_op_f32_sparse_sm80.cu +++ 
b/test/unit/gemm/device/gemm_f16t_f16n_f32t_tensor_op_f32_sparse_sm80.cu @@ -279,45 +279,7 @@ TEST(SM80_Device_Sparse_Gemm_f16t_f16n_f32t_tensor_op_f32, 32x128x128_32x32x128) cutlass::epilogue::thread::LinearCombination< ElementOutput, 128 / cutlass::sizeof_bits::value, ElementAccumulator, ElementAccumulator>, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 6>; - - EXPECT_TRUE(test::gemm::device::TestAllSparseGemm()); -} - -#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 900) -TEST(SM80_Device_Sparse_Gemm_f16t_f16n_f32t_tensor_op_f32, 32x256x64_32x64x64) { - using ElementOutput = float; - using ElementAccumulator = float; - - using Gemm = cutlass::gemm::device::GemmSparseUniversal< - cutlass::half_t, cutlass::layout::RowMajor, cutlass::half_t, - cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, - ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, - cutlass::gemm::GemmShape<32, 256, 64>, - cutlass::gemm::GemmShape<32, 64, 64>, cutlass::gemm::GemmShape<16, 8, 32>, - cutlass::epilogue::thread::LinearCombination< - ElementOutput, 128 / cutlass::sizeof_bits::value, - ElementAccumulator, ElementAccumulator>, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 6>; - - EXPECT_TRUE(test::gemm::device::TestAllSparseGemm()); -} -#endif - -TEST(SM80_Device_Sparse_Gemm_f16t_f16n_f32t_tensor_op_f32, 32x256x128_32x64x128) { - using ElementOutput = float; - using ElementAccumulator = float; - - using Gemm = cutlass::gemm::device::GemmSparseUniversal< - cutlass::half_t, cutlass::layout::RowMajor, cutlass::half_t, - cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, - ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, - cutlass::gemm::GemmShape<32, 256, 128>, - cutlass::gemm::GemmShape<32, 64, 128>, cutlass::gemm::GemmShape<16, 8, 32>, - cutlass::epilogue::thread::LinearCombination< - ElementOutput, 128 / cutlass::sizeof_bits::value, - ElementAccumulator, ElementAccumulator>, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 6>; + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; EXPECT_TRUE(test::gemm::device::TestAllSparseGemm()); } @@ -376,24 +338,6 @@ TEST(SM80_Device_Sparse_Gemm_f16t_f16n_f32t_tensor_op_f32, 256x32x64_64x32x64) { EXPECT_TRUE(test::gemm::device::TestAllSparseGemm()); } -TEST(SM80_Device_Sparse_Gemm_f16t_f16n_f32t_tensor_op_f32, 256x32x128_64x32x128) { - using ElementOutput = float; - using ElementAccumulator = float; - - using Gemm = cutlass::gemm::device::GemmSparseUniversal< - cutlass::half_t, cutlass::layout::RowMajor, cutlass::half_t, - cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, - ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, - cutlass::gemm::GemmShape<256, 32, 128>, - cutlass::gemm::GemmShape<64, 32, 128>, cutlass::gemm::GemmShape<16, 8, 32>, - cutlass::epilogue::thread::LinearCombination< - ElementOutput, 128 / cutlass::sizeof_bits::value, - ElementAccumulator, ElementAccumulator>, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 6>; - - EXPECT_TRUE(test::gemm::device::TestAllSparseGemm()); -} - TEST(SM80_Device_Sparse_Gemm_f16t_f16n_f32t_tensor_op_f32, 128x16x64_32x16x64) { using ElementOutput = float; using ElementAccumulator = float; @@ -448,24 +392,6 @@ TEST(SM80_Device_Sparse_Gemm_f16t_f16n_f32t_tensor_op_f32, 256x16x64_64x16x64) { EXPECT_TRUE(test::gemm::device::TestAllSparseGemm()); } -TEST(SM80_Device_Sparse_Gemm_f16t_f16n_f32t_tensor_op_f32, 
256x16x128_64x16x128) { - using ElementOutput = float; - using ElementAccumulator = float; - - using Gemm = cutlass::gemm::device::GemmSparseUniversal< - cutlass::half_t, cutlass::layout::RowMajor, cutlass::half_t, - cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, - ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, - cutlass::gemm::GemmShape<256, 16, 128>, - cutlass::gemm::GemmShape<64, 16, 128>, cutlass::gemm::GemmShape<16, 8, 32>, - cutlass::epilogue::thread::LinearCombination< - ElementOutput, 128 / cutlass::sizeof_bits::value, - ElementAccumulator, ElementAccumulator>, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 6>; - - EXPECT_TRUE(test::gemm::device::TestAllSparseGemm()); -} - //////////////////////////////////////////////////////////////////////////////// #endif // CUTLASS_ARCH_SPARSE_MMA_SM80_SUPPORTED diff --git a/test/unit/gemm/device/gemm_f32n_f32t_f32t_tensor_op_f32_sparse_sm80.cu b/test/unit/gemm/device/gemm_f32n_f32t_f32t_tensor_op_f32_sparse_sm80.cu index 8ce0f138f2..452250fb3d 100644 --- a/test/unit/gemm/device/gemm_f32n_f32t_f32t_tensor_op_f32_sparse_sm80.cu +++ b/test/unit/gemm/device/gemm_f32n_f32t_f32t_tensor_op_f32_sparse_sm80.cu @@ -449,75 +449,12 @@ TEST(SM80_Device_Sparse_Gemm_f32n_f32t_f32t_tensor_op_f32, 32x128x64_32x32x64) { ElementAccumulator >, cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, - 6 + 3 >; EXPECT_TRUE(test::gemm::device::TestAllSparseGemm()); } -#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 900) -TEST(SM80_Device_Sparse_Gemm_f32n_f32t_f32t_tensor_op_f32, 32x256x32_32x64x32) { - - using ElementOutput = float; - using ElementAccumulator = float; - - using Gemm = cutlass::gemm::device::GemmSparseUniversal< - float, - cutlass::layout::ColumnMajor, - float, - cutlass::layout::RowMajor, - float, - cutlass::layout::RowMajor, - float, - cutlass::arch::OpClassTensorOp, - cutlass::arch::Sm80, - cutlass::gemm::GemmShape<32, 256, 32>, - cutlass::gemm::GemmShape<32, 64, 32>, - cutlass::gemm::GemmShape<16, 8, 16>, - cutlass::epilogue::thread::LinearCombination< - ElementOutput, - 128 / cutlass::sizeof_bits::value, - ElementAccumulator, - ElementAccumulator - >, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, - 6 - >; - - EXPECT_TRUE(test::gemm::device::TestAllSparseGemm()); -} -#endif - -TEST(SM80_Device_Sparse_Gemm_f32n_f32t_f32t_tensor_op_f32, 32x256x64_32x64x64) { - - using ElementOutput = float; - using ElementAccumulator = float; - - using Gemm = cutlass::gemm::device::GemmSparseUniversal< - float, - cutlass::layout::ColumnMajor, - float, - cutlass::layout::RowMajor, - float, - cutlass::layout::RowMajor, - float, - cutlass::arch::OpClassTensorOp, - cutlass::arch::Sm80, - cutlass::gemm::GemmShape<32, 256, 64>, - cutlass::gemm::GemmShape<32, 64, 64>, - cutlass::gemm::GemmShape<16, 8, 16>, - cutlass::epilogue::thread::LinearCombination< - ElementOutput, - 128 / cutlass::sizeof_bits::value, - ElementAccumulator, - ElementAccumulator - >, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, - 6 - >; - - EXPECT_TRUE(test::gemm::device::TestAllSparseGemm()); -} ///////////////////////////////////////////////////////////////////////////////////////////////// #endif // #if defined(CUTLASS_ARCH_SPARSE_MMA_SM80_SUPPORTED) diff --git a/test/unit/gemm/device/gemm_f32t_f32n_f32t_tensor_op_f32_sparse_sm80.cu b/test/unit/gemm/device/gemm_f32t_f32n_f32t_tensor_op_f32_sparse_sm80.cu index 382b0b2261..5fd7d7d475 100644 --- 
a/test/unit/gemm/device/gemm_f32t_f32n_f32t_tensor_op_f32_sparse_sm80.cu +++ b/test/unit/gemm/device/gemm_f32t_f32n_f32t_tensor_op_f32_sparse_sm80.cu @@ -449,71 +449,7 @@ TEST(SM80_Device_Sparse_Gemm_f32t_f32n_f32t_tensor_op_f32, 32x128x64_32x32x64) { ElementAccumulator >, cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, - 6 - >; - - EXPECT_TRUE(test::gemm::device::TestAllSparseGemm()); -} - -#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 900) -TEST(SM80_Device_Sparse_Gemm_f32t_f32n_f32t_tensor_op_f32, 32x256x32_32x64x32) { - - using ElementOutput = float; - using ElementAccumulator = float; - - using Gemm = cutlass::gemm::device::GemmSparseUniversal< - float, - cutlass::layout::RowMajor, - float, - cutlass::layout::ColumnMajor, - float, - cutlass::layout::RowMajor, - float, - cutlass::arch::OpClassTensorOp, - cutlass::arch::Sm80, - cutlass::gemm::GemmShape<32, 256, 32>, - cutlass::gemm::GemmShape<32, 64, 32>, - cutlass::gemm::GemmShape<16, 8, 16>, - cutlass::epilogue::thread::LinearCombination< - ElementOutput, - 128 / cutlass::sizeof_bits::value, - ElementAccumulator, - ElementAccumulator - >, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, - 6 - >; - - EXPECT_TRUE(test::gemm::device::TestAllSparseGemm()); -} -#endif - -TEST(SM80_Device_Sparse_Gemm_f32t_f32n_f32t_tensor_op_f32, 32x256x64_32x64x64) { - - using ElementOutput = float; - using ElementAccumulator = float; - - using Gemm = cutlass::gemm::device::GemmSparseUniversal< - float, - cutlass::layout::RowMajor, - float, - cutlass::layout::ColumnMajor, - float, - cutlass::layout::RowMajor, - float, - cutlass::arch::OpClassTensorOp, - cutlass::arch::Sm80, - cutlass::gemm::GemmShape<32, 256, 64>, - cutlass::gemm::GemmShape<32, 64, 64>, - cutlass::gemm::GemmShape<16, 8, 16>, - cutlass::epilogue::thread::LinearCombination< - ElementOutput, - 128 / cutlass::sizeof_bits::value, - ElementAccumulator, - ElementAccumulator - >, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, - 6 + 3 >; EXPECT_TRUE(test::gemm::device::TestAllSparseGemm()); @@ -612,37 +548,6 @@ TEST(SM80_Device_Sparse_Gemm_f32t_f32n_f32t_tensor_op_f32, 256x32x32_64x32x32) { EXPECT_TRUE(test::gemm::device::TestAllSparseGemm()); } -TEST(SM80_Device_Sparse_Gemm_f32t_f32n_f32t_tensor_op_f32, 256x32x64_64x32x64) { - - using ElementOutput = float; - using ElementAccumulator = float; - - using Gemm = cutlass::gemm::device::GemmSparseUniversal< - float, - cutlass::layout::RowMajor, - float, - cutlass::layout::ColumnMajor, - float, - cutlass::layout::RowMajor, - float, - cutlass::arch::OpClassTensorOp, - cutlass::arch::Sm80, - cutlass::gemm::GemmShape<256, 32, 64>, - cutlass::gemm::GemmShape<64, 32, 64>, - cutlass::gemm::GemmShape<16, 8, 16>, - cutlass::epilogue::thread::LinearCombination< - ElementOutput, - 128 / cutlass::sizeof_bits::value, - ElementAccumulator, - ElementAccumulator - >, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, - 6 - >; - - EXPECT_TRUE(test::gemm::device::TestAllSparseGemm()); -} - TEST(SM80_Device_Sparse_Gemm_f32t_f32n_f32t_tensor_op_f32, 128x16x32_32x16x32) { using ElementOutput = float; @@ -736,37 +641,6 @@ TEST(SM80_Device_Sparse_Gemm_f32t_f32n_f32t_tensor_op_f32, 256x16x32_64x16x32) { EXPECT_TRUE(test::gemm::device::TestAllSparseGemm()); } -TEST(SM80_Device_Sparse_Gemm_f32t_f32n_f32t_tensor_op_f32, 256x16x64_64x16x64) { - - using ElementOutput = float; - using ElementAccumulator = float; - - using Gemm = cutlass::gemm::device::GemmSparseUniversal< - float, - cutlass::layout::RowMajor, - float, - 
cutlass::layout::ColumnMajor, - float, - cutlass::layout::RowMajor, - float, - cutlass::arch::OpClassTensorOp, - cutlass::arch::Sm80, - cutlass::gemm::GemmShape<256, 16, 64>, - cutlass::gemm::GemmShape<64, 16, 64>, - cutlass::gemm::GemmShape<16, 8, 16>, - cutlass::epilogue::thread::LinearCombination< - ElementOutput, - 128 / cutlass::sizeof_bits::value, - ElementAccumulator, - ElementAccumulator - >, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, - 6 - >; - - EXPECT_TRUE(test::gemm::device::TestAllSparseGemm()); -} - ///////////////////////////////////////////////////////////////////////////////////////////////// #endif // #if defined(CUTLASS_ARCH_SPARSE_MMA_SM80_SUPPORTED) diff --git a/test/unit/gemm/device/gemm_s4t_s4n_s32t_tensor_op_s32_sparse_sm80.cu b/test/unit/gemm/device/gemm_s4t_s4n_s32t_tensor_op_s32_sparse_sm80.cu index bea3e946b8..73d45d5489 100644 --- a/test/unit/gemm/device/gemm_s4t_s4n_s32t_tensor_op_s32_sparse_sm80.cu +++ b/test/unit/gemm/device/gemm_s4t_s4n_s32t_tensor_op_s32_sparse_sm80.cu @@ -275,7 +275,7 @@ TEST(SM80_Device_Sparse_Gemm_s4t_s4n_s32t_tensor_op_s32, 32x128x512_32x32x512) { cutlass::epilogue::thread::LinearCombinationClamp< ElementOutput, 128 / cutlass::sizeof_bits::value, ElementAccumulator, ElementCompute>, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 6>; + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; EXPECT_TRUE(test::gemm::device::TestAllSparseGemm()); } @@ -313,26 +313,7 @@ TEST(SM80_Device_Sparse_Gemm_s4t_s4n_s32t_tensor_op_s32, 32x256x256_32x64x256) { cutlass::epilogue::thread::LinearCombinationClamp< ElementOutput, 128 / cutlass::sizeof_bits::value, ElementAccumulator, ElementCompute>, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 6>; - - EXPECT_TRUE(test::gemm::device::TestAllSparseGemm()); -} - -TEST(SM80_Device_Sparse_Gemm_s4t_s4n_s32t_tensor_op_s32, 32x256x512_32x64x512) { - using ElementOutput = int32_t; - using ElementAccumulator = int32_t; - using ElementCompute = int32_t; - - using Gemm = cutlass::gemm::device::GemmSparseUniversal< - cutlass::int4b_t, cutlass::layout::RowMajor, cutlass::int4b_t, - cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, - ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, - cutlass::gemm::GemmShape<32, 256, 512>, - cutlass::gemm::GemmShape<32, 64, 512>, cutlass::gemm::GemmShape<16, 8, 128>, - cutlass::epilogue::thread::LinearCombinationClamp< - ElementOutput, 128 / cutlass::sizeof_bits::value, - ElementAccumulator, ElementCompute>, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 6>; + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 4>; EXPECT_TRUE(test::gemm::device::TestAllSparseGemm()); } @@ -351,26 +332,7 @@ TEST(SM80_Device_Sparse_Gemm_s4t_s4n_s32t_tensor_op_s32, 16x128x512_16x32x512) { cutlass::epilogue::thread::LinearCombinationClamp< ElementOutput, 128 / cutlass::sizeof_bits::value, ElementAccumulator, ElementCompute>, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 6>; - - EXPECT_TRUE(test::gemm::device::TestAllSparseGemm()); -} - -TEST(SM80_Device_Sparse_Gemm_s4t_s4n_s32t_tensor_op_s32, 16x256x512_16x64x512) { - using ElementOutput = int32_t; - using ElementAccumulator = int32_t; - using ElementCompute = int32_t; - - using Gemm = cutlass::gemm::device::GemmSparseUniversal< - cutlass::int4b_t, cutlass::layout::RowMajor, cutlass::int4b_t, - cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, - ElementAccumulator, 
cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, - cutlass::gemm::GemmShape<16, 256, 512>, - cutlass::gemm::GemmShape<16, 64, 512>, cutlass::gemm::GemmShape<16, 8, 128>, - cutlass::epilogue::thread::LinearCombinationClamp< - ElementOutput, 128 / cutlass::sizeof_bits::value, - ElementAccumulator, ElementCompute>, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 6>; + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 4>; EXPECT_TRUE(test::gemm::device::TestAllSparseGemm()); } @@ -408,7 +370,7 @@ TEST(SM80_Device_Sparse_Gemm_s4t_s4n_s32t_tensor_op_s32, 128x32x512_32x32x512) { cutlass::epilogue::thread::LinearCombinationClamp< ElementOutput, 128 / cutlass::sizeof_bits::value, ElementAccumulator, ElementCompute>, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 6>; + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; EXPECT_TRUE(test::gemm::device::TestAllSparseGemm()); } @@ -432,25 +394,6 @@ TEST(SM80_Device_Sparse_Gemm_s4t_s4n_s32t_tensor_op_s32, 256x32x256_64x32x256) { EXPECT_TRUE(test::gemm::device::TestAllSparseGemm()); } -TEST(SM80_Device_Sparse_Gemm_s4t_s4n_s32t_tensor_op_s32, 256x32x512_64x32x512) { - using ElementOutput = int32_t; - using ElementAccumulator = int32_t; - using ElementCompute = int32_t; - - using Gemm = cutlass::gemm::device::GemmSparseUniversal< - cutlass::int4b_t, cutlass::layout::RowMajor, cutlass::int4b_t, - cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, - ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, - cutlass::gemm::GemmShape<256, 32, 512>, - cutlass::gemm::GemmShape<64, 32, 512>, cutlass::gemm::GemmShape<16, 8, 128>, - cutlass::epilogue::thread::LinearCombinationClamp< - ElementOutput, 128 / cutlass::sizeof_bits::value, - ElementAccumulator, ElementCompute>, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 6>; - - EXPECT_TRUE(test::gemm::device::TestAllSparseGemm()); -} - TEST(SM80_Device_Sparse_Gemm_s4t_s4n_s32t_tensor_op_s32, 128x16x256_32x16x256) { using ElementOutput = int32_t; using ElementAccumulator = int32_t; @@ -508,25 +451,6 @@ TEST(SM80_Device_Sparse_Gemm_s4t_s4n_s32t_tensor_op_s32, 256x16x256_16x64x256) { EXPECT_TRUE(test::gemm::device::TestAllSparseGemm()); } -TEST(SM80_Device_Sparse_Gemm_s4t_s4n_s32t_tensor_op_s32, 256x16x512_16x64x512) { - using ElementOutput = int32_t; - using ElementAccumulator = int32_t; - using ElementCompute = int32_t; - - using Gemm = cutlass::gemm::device::GemmSparseUniversal< - cutlass::int4b_t, cutlass::layout::RowMajor, cutlass::int4b_t, - cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, - ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, - cutlass::gemm::GemmShape<256, 16, 512>, - cutlass::gemm::GemmShape<64, 16, 512>, cutlass::gemm::GemmShape<16, 8, 128>, - cutlass::epilogue::thread::LinearCombinationClamp< - ElementOutput, 128 / cutlass::sizeof_bits::value, - ElementAccumulator, ElementCompute>, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 6>; - - EXPECT_TRUE(test::gemm::device::TestAllSparseGemm()); -} - //////////////////////////////////////////////////////////////////////////////// #endif // defined(CUTLASS_ARCH_SPARSE_MMA_SM80_SUPPORTED) diff --git a/test/unit/gemm/device/gemm_s8t_s8n_s32t_tensor_op_s32_sparse_sm80.cu b/test/unit/gemm/device/gemm_s8t_s8n_s32t_tensor_op_s32_sparse_sm80.cu index 4cb879b635..96b56322cf 100644 --- a/test/unit/gemm/device/gemm_s8t_s8n_s32t_tensor_op_s32_sparse_sm80.cu +++ 
b/test/unit/gemm/device/gemm_s8t_s8n_s32t_tensor_op_s32_sparse_sm80.cu @@ -294,7 +294,7 @@ TEST(SM80_Device_Sparse_Gemm_s8t_s8n_s32t_tensor_op_s32, 32x128x256_32x32x256) { cutlass::epilogue::thread::LinearCombinationClamp< ElementOutput, 128 / cutlass::sizeof_bits::value, ElementAccumulator, ElementCompute>, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 6>; + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; EXPECT_TRUE(test::gemm::device::TestAllSparseGemm()); } @@ -313,26 +313,7 @@ TEST(SM80_Device_Sparse_Gemm_s8t_s8n_s32t_tensor_op_s32, 32x256x128_32x64x128) { cutlass::epilogue::thread::LinearCombinationClamp< ElementOutput, 128 / cutlass::sizeof_bits::value, ElementAccumulator, ElementCompute>, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 6>; - - EXPECT_TRUE(test::gemm::device::TestAllSparseGemm()); -} - -TEST(SM80_Device_Sparse_Gemm_s8t_s8n_s32t_tensor_op_s32, 32x256x256_32x64x256) { - using ElementOutput = int32_t; - using ElementAccumulator = int32_t; - using ElementCompute = int32_t; - - using Gemm = cutlass::gemm::device::GemmSparseUniversal< - int8_t, cutlass::layout::RowMajor, int8_t, - cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, - ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, - cutlass::gemm::GemmShape<32, 256, 256>, - cutlass::gemm::GemmShape<32, 64, 256>, cutlass::gemm::GemmShape<16, 8, 64>, - cutlass::epilogue::thread::LinearCombinationClamp< - ElementOutput, 128 / cutlass::sizeof_bits::value, - ElementAccumulator, ElementCompute>, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 6>; + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 4>; EXPECT_TRUE(test::gemm::device::TestAllSparseGemm()); } @@ -351,26 +332,7 @@ TEST(SM80_Device_Sparse_Gemm_s8t_s8n_s32t_tensor_op_s32, 16x128x256_16x32x256) { cutlass::epilogue::thread::LinearCombinationClamp< ElementOutput, 128 / cutlass::sizeof_bits::value, ElementAccumulator, ElementCompute>, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 6>; - - EXPECT_TRUE(test::gemm::device::TestAllSparseGemm()); -} - -TEST(SM80_Device_Sparse_Gemm_s8t_s8n_s32t_tensor_op_s32, 16x128x256_32x32x256) { - using ElementOutput = int32_t; - using ElementAccumulator = int32_t; - using ElementCompute = int32_t; - - using Gemm = cutlass::gemm::device::GemmSparseUniversal< - int8_t, cutlass::layout::RowMajor, int8_t, - cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, - ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, - cutlass::gemm::GemmShape<16, 128, 256>, - cutlass::gemm::GemmShape<16, 32, 256>, cutlass::gemm::GemmShape<16, 8, 64>, - cutlass::epilogue::thread::LinearCombinationClamp< - ElementOutput, 128 / cutlass::sizeof_bits::value, - ElementAccumulator, ElementCompute>, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 6>; + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 4>; EXPECT_TRUE(test::gemm::device::TestAllSparseGemm()); } @@ -408,7 +370,7 @@ TEST(SM80_Device_Sparse_Gemm_s8t_s8n_s32t_tensor_op_s32, 128x32x256_32x32x256) { cutlass::epilogue::thread::LinearCombinationClamp< ElementOutput, 128 / cutlass::sizeof_bits::value, ElementAccumulator, ElementCompute>, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 6>; + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; EXPECT_TRUE(test::gemm::device::TestAllSparseGemm()); } @@ -432,25 +394,6 @@ TEST(SM80_Device_Sparse_Gemm_s8t_s8n_s32t_tensor_op_s32, 
256x32x128_64x32x128) { EXPECT_TRUE(test::gemm::device::TestAllSparseGemm()); } -TEST(SM80_Device_Sparse_Gemm_s8t_s8n_s32t_tensor_op_s32, 256x32x256_64x32x256) { - using ElementOutput = int32_t; - using ElementAccumulator = int32_t; - using ElementCompute = int32_t; - - using Gemm = cutlass::gemm::device::GemmSparseUniversal< - int8_t, cutlass::layout::RowMajor, int8_t, - cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, - ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, - cutlass::gemm::GemmShape<256, 32, 256>, - cutlass::gemm::GemmShape<64, 32, 256>, cutlass::gemm::GemmShape<16, 8, 64>, - cutlass::epilogue::thread::LinearCombinationClamp< - ElementOutput, 128 / cutlass::sizeof_bits::value, - ElementAccumulator, ElementCompute>, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 6>; - - EXPECT_TRUE(test::gemm::device::TestAllSparseGemm()); -} - TEST(SM80_Device_Sparse_Gemm_s8t_s8n_s32t_tensor_op_s32, 128x16x128_32x16x128) { using ElementOutput = int32_t; using ElementAccumulator = int32_t; @@ -508,25 +451,6 @@ TEST(SM80_Device_Sparse_Gemm_s8t_s8n_s32t_tensor_op_s32, 256x16x128_64x16x128) { EXPECT_TRUE(test::gemm::device::TestAllSparseGemm()); } -TEST(SM80_Device_Sparse_Gemm_s8t_s8n_s32t_tensor_op_s32, 256x16x256_64x16x256) { - using ElementOutput = int32_t; - using ElementAccumulator = int32_t; - using ElementCompute = int32_t; - - using Gemm = cutlass::gemm::device::GemmSparseUniversal< - int8_t, cutlass::layout::RowMajor, int8_t, - cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, - ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, - cutlass::gemm::GemmShape<256, 16, 256>, - cutlass::gemm::GemmShape<64, 16, 256>, cutlass::gemm::GemmShape<16, 8, 64>, - cutlass::epilogue::thread::LinearCombinationClamp< - ElementOutput, 128 / cutlass::sizeof_bits::value, - ElementAccumulator, ElementCompute>, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 6>; - - EXPECT_TRUE(test::gemm::device::TestAllSparseGemm()); -} - //////////////////////////////////////////////////////////////////////////////// #endif // defined(CUTLASS_ARCH_SPARSE_MMA_SM80_SUPPORTED) diff --git a/test/unit/gemm/device/multistage_testbed.h b/test/unit/gemm/device/multistage_testbed.h index 2d6f37652b..2fc718648f 100644 --- a/test/unit/gemm/device/multistage_testbed.h +++ b/test/unit/gemm/device/multistage_testbed.h @@ -141,10 +141,10 @@ struct MultistageTestbed { ElementCompute alpha = ElementCompute(1), ElementCompute beta = ElementCompute(0)) { - // Waives test if CUDA device is insufficient - if (!sufficient()) { - return true; - } + // Waives test if CUDA device is insufficient + if (!sufficient()) { + return true; + } // // Allocate the GEMM workspace diff --git a/test/unit/gemm/device/syrk_f32n_f32t_tensor_op_fast_f32_sm80.cu b/test/unit/gemm/device/syrk_f32n_f32t_tensor_op_fast_f32_sm80.cu index b39d9945c8..58486bf096 100644 --- a/test/unit/gemm/device/syrk_f32n_f32t_tensor_op_fast_f32_sm80.cu +++ b/test/unit/gemm/device/syrk_f32n_f32t_tensor_op_fast_f32_sm80.cu @@ -78,7 +78,7 @@ TEST(SM80_Device_Syrk_f32n_f32t_l_tensor_op_fast_f32, 128x256x32_64x64x32) { ElementAccumulator >, cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, - 4 + 3 >; EXPECT_TRUE(test::gemm::device::TestAllRankKUniversal()); diff --git a/test/unit/gemm/device/syrk_f32t_f32t_tensor_op_fast_f32_sm80.cu b/test/unit/gemm/device/syrk_f32t_f32t_tensor_op_fast_f32_sm80.cu index 57b16cfba0..8b3b22ac67 100644 --- 
a/test/unit/gemm/device/syrk_f32t_f32t_tensor_op_fast_f32_sm80.cu +++ b/test/unit/gemm/device/syrk_f32t_f32t_tensor_op_fast_f32_sm80.cu @@ -78,7 +78,7 @@ TEST(SM80_Device_Syrk_f32t_f32t_l_tensor_op_fast_f32, 128x256x32_64x64x32) { ElementAccumulator >, cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, - 4 + 3 >; EXPECT_TRUE(test::gemm::device::TestAllRankKUniversal()); diff --git a/test/unit/gemm/device/syrk_tf32n_f32t_tensor_op_f32_sm80.cu b/test/unit/gemm/device/syrk_tf32n_f32t_tensor_op_f32_sm80.cu index ccd31b8662..6d33a3620b 100644 --- a/test/unit/gemm/device/syrk_tf32n_f32t_tensor_op_f32_sm80.cu +++ b/test/unit/gemm/device/syrk_tf32n_f32t_tensor_op_f32_sm80.cu @@ -78,7 +78,7 @@ TEST(SM80_Device_Syrk_tf32n_f32t_l_tensor_op_f32, 128x256x32_64x64x32) { ElementAccumulator >, cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, - 4 + 3 >; EXPECT_TRUE(test::gemm::device::TestAllRankKUniversal()); diff --git a/test/unit/gemm/device/syrk_tf32t_f32t_tensor_op_f32_sm80.cu b/test/unit/gemm/device/syrk_tf32t_f32t_tensor_op_f32_sm80.cu index 49463ef433..530c4ba7e1 100644 --- a/test/unit/gemm/device/syrk_tf32t_f32t_tensor_op_f32_sm80.cu +++ b/test/unit/gemm/device/syrk_tf32t_f32t_tensor_op_f32_sm80.cu @@ -78,7 +78,7 @@ TEST(SM80_Device_Syrk_tf32t_f32t_l_tensor_op_f32, 128x256x32_64x64x32) { ElementAccumulator >, cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, - 4 + 3 >; EXPECT_TRUE(test::gemm::device::TestAllRankKUniversal()); From b0296bf682375b56401c17901693218c54b11b8b Mon Sep 17 00:00:00 2001 From: Haicheng Wu Date: Thu, 15 Aug 2024 21:06:01 -0700 Subject: [PATCH 08/53] fix uint128 --- include/cutlass/uint128.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/include/cutlass/uint128.h b/include/cutlass/uint128.h index 0a41e95cd3..096d90f1fb 100644 --- a/include/cutlass/uint128.h +++ b/include/cutlass/uint128.h @@ -194,7 +194,7 @@ struct alignas(16) uint128_t uint64_t remainder{0}; #if defined(CUTLASS_UINT128_NATIVE) remainder = uint64_t(native % divisor); -#elif defined(CUTLASS_INT128_ARITHMETIC_DIV) +#elif defined(CUTLASS_INT128_ARITHMETIC_DIV) && ! defined (__CUDA_ARCH__) // implemented using MSVC's arithmetic intrinsics (void)_udiv128(hilo_.hi, hilo_.lo, divisor, &remainder); #else From 3f084f7f3c07d18066fb971823009aad9e00f77d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Aleksandar=20Samard=C5=BEi=C4=87?= <115986737+alexsamardzic@users.noreply.github.com> Date: Fri, 16 Aug 2024 06:59:29 +0200 Subject: [PATCH 09/53] Add couple configs into generator.py for mixed input MM (#1350) * Add couple configs into generator.py for mixed input MM * change one unit test name; reenable 128x32 in the profiler * Added U8/BF16 tests. 
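
For reference, every new unit test added by this change instantiates the same kernel configuration and only varies the A/B/C element types and the A/B alignments. A condensed sketch, taken from the f16t_u8n_f32t test added below (the s8/bf16 variants differ only in the element types, and the 8/16 alignments swap when the 8-bit operand is A):

    using Gemm = cutlass::gemm::device::GemmUniversal<
      cutlass::half_t, cutlass::layout::RowMajor,      // A: 16-bit operand
      uint8_t,         cutlass::layout::ColumnMajor,   // B: 8-bit operand, upcast in the MMA
      float,           cutlass::layout::RowMajor,      // C/D
      float,                                           // accumulator
      cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80,
      cutlass::gemm::GemmShape<128, 128, 64>,          // threadblock tile
      cutlass::gemm::GemmShape<64, 64, 64>,            // warp tile
      cutlass::gemm::GemmShape<16, 8, 16>,             // instruction shape
      cutlass::epilogue::thread::LinearCombination<
        float, 128 / cutlass::sizeof_bits<float>::value, float, float>,
      cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
      4,                                               // Stages
      8,                                               // AlignmentA
      16,                                              // AlignmentB
      cutlass::arch::OpMultiplyAddMixedInputUpcast,
      cutlass::ComplexTransform::kNone,
      cutlass::ComplexTransform::kNone>;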
--------- Co-authored-by: Haicheng Wu Co-authored-by: Haicheng Wu <57973641+hwu36@users.noreply.github.com> --- python/cutlass_library/gemm_operation.py | 30 ++---- python/cutlass_library/generator.py | 47 ++++++--- test/unit/gemm/device/CMakeLists.txt | 26 ++++- ...s8n_f32t_mixed_input_tensor_op_f32_sm80.cu | 97 +++++++++++++++++++ ...8n_bf16t_mixed_input_tensor_op_f32_sm80.cu | 97 +++++++++++++++++++ ...u8n_f32t_mixed_input_tensor_op_f32_sm80.cu | 97 +++++++++++++++++++ ...s8n_f16t_mixed_input_tensor_op_f32_sm80.cu | 97 +++++++++++++++++++ ...s8n_f32t_mixed_input_tensor_op_f32_sm80.cu | 97 +++++++++++++++++++ ...u8n_f16t_mixed_input_tensor_op_f16_sm80.cu | 2 +- ...u8n_f16t_mixed_input_tensor_op_f32_sm80.cu | 97 +++++++++++++++++++ ...u8n_f32t_mixed_input_tensor_op_f32_sm80.cu | 97 +++++++++++++++++++ ...16n_f32t_mixed_input_tensor_op_f32_sm80.cu | 97 +++++++++++++++++++ ...16n_f16t_mixed_input_tensor_op_f32_sm80.cu | 97 +++++++++++++++++++ ...16n_f32t_mixed_input_tensor_op_f32_sm80.cu | 97 +++++++++++++++++++ ...6n_bf16t_mixed_input_tensor_op_f32_sm80.cu | 97 +++++++++++++++++++ ...16n_f32t_mixed_input_tensor_op_f32_sm80.cu | 97 +++++++++++++++++++ ...16n_f16t_mixed_input_tensor_op_f16_sm80.cu | 2 +- ...16n_f16t_mixed_input_tensor_op_f32_sm80.cu | 97 +++++++++++++++++++ ...16n_f32t_mixed_input_tensor_op_f32_sm80.cu | 97 +++++++++++++++++++ test/unit/gemm/device/testbed_universal.h | 9 +- .../src/reference/gemm_fp_mixed_input.cu | 74 +++++++++++--- 21 files changed, 1487 insertions(+), 61 deletions(-) create mode 100644 test/unit/gemm/device/gemm_universal_bf16t_s8n_f32t_mixed_input_tensor_op_f32_sm80.cu create mode 100644 test/unit/gemm/device/gemm_universal_bf16t_u8n_bf16t_mixed_input_tensor_op_f32_sm80.cu create mode 100644 test/unit/gemm/device/gemm_universal_bf16t_u8n_f32t_mixed_input_tensor_op_f32_sm80.cu create mode 100644 test/unit/gemm/device/gemm_universal_f16t_s8n_f16t_mixed_input_tensor_op_f32_sm80.cu create mode 100644 test/unit/gemm/device/gemm_universal_f16t_s8n_f32t_mixed_input_tensor_op_f32_sm80.cu create mode 100644 test/unit/gemm/device/gemm_universal_f16t_u8n_f16t_mixed_input_tensor_op_f32_sm80.cu create mode 100644 test/unit/gemm/device/gemm_universal_f16t_u8n_f32t_mixed_input_tensor_op_f32_sm80.cu create mode 100644 test/unit/gemm/device/gemm_universal_s8t_bf16n_f32t_mixed_input_tensor_op_f32_sm80.cu create mode 100644 test/unit/gemm/device/gemm_universal_s8t_f16n_f16t_mixed_input_tensor_op_f32_sm80.cu create mode 100644 test/unit/gemm/device/gemm_universal_s8t_f16n_f32t_mixed_input_tensor_op_f32_sm80.cu create mode 100644 test/unit/gemm/device/gemm_universal_u8t_bf16n_bf16t_mixed_input_tensor_op_f32_sm80.cu create mode 100644 test/unit/gemm/device/gemm_universal_u8t_bf16n_f32t_mixed_input_tensor_op_f32_sm80.cu create mode 100644 test/unit/gemm/device/gemm_universal_u8t_f16n_f16t_mixed_input_tensor_op_f32_sm80.cu create mode 100644 test/unit/gemm/device/gemm_universal_u8t_f16n_f32t_mixed_input_tensor_op_f32_sm80.cu diff --git a/python/cutlass_library/gemm_operation.py b/python/cutlass_library/gemm_operation.py index 5e015492f4..6159148da4 100644 --- a/python/cutlass_library/gemm_operation.py +++ b/python/cutlass_library/gemm_operation.py @@ -178,30 +178,16 @@ def extended_name(self): if self.is_complex(): extended_name = "${core_name}" else: - # e.g. 
f16_f16_f32_void_f32 kernel - if self.C.element != self.tile_description.math_instruction.element_accumulator and \ - self.A.element != self.tile_description.math_instruction.element_accumulator: - extended_name = "${element_c}_${core_name}_${element_a}" - if self.is_mixed_input(): - extended_name += "_${element_b}" - - # e.g. f32_f32_f32_void_f32 kernel - elif self.C.element != self.tile_description.math_instruction.element_accumulator and \ - self.A.element == self.tile_description.math_instruction.element_accumulator: - extended_name = "${element_c}_${core_name}" - if self.is_mixed_input(): - extended_name += "_${element_b}" - - # e.g. f16_f16_f32_f32_f32 kernel - elif self.C.element == self.tile_description.math_instruction.element_accumulator and \ - self.A.element != self.tile_description.math_instruction.element_accumulator: - extended_name = "${core_name}_${element_a}" - if self.is_mixed_input(): - extended_name += "_${element_b}" - - # e.g. f32_f32_f32_f32_f32 kernel + if self.is_mixed_input(): + extended_name = "${core_name}_${element_a}_${element_b}" + if self.C.element != self.tile_description.math_instruction.element_accumulator: + extended_name = "${element_c}_" + extended_name else: extended_name = "${core_name}" + if self.C.element != self.tile_description.math_instruction.element_accumulator: + extended_name = "${element_c}_" + extended_name + if self.A.element != self.tile_description.math_instruction.element_accumulator: + extended_name += "_${element_a}" extended_name = SubstituteTemplate(extended_name, { 'element_a': DataTypeNames[self.A.element], diff --git a/python/cutlass_library/generator.py b/python/cutlass_library/generator.py index 1f2eb86ecc..7a5a47b196 100644 --- a/python/cutlass_library/generator.py +++ b/python/cutlass_library/generator.py @@ -2575,17 +2575,17 @@ def GenerateSM80_TensorOp_16816_mixed_input_upcast_a(manifest, cuda_version): math_instructions = [ MathInstruction( \ [16, 8, 16], \ - DataType.s8, DataType.f16, DataType.f16, \ + DataType.s8, DataType.f16, DataType.f32, \ OpcodeClass.TensorOp, \ MathOperation.multiply_add_mixed_input_upcast), MathInstruction( \ [16, 8, 16], \ - DataType.s8, DataType.f16, DataType.f32, \ + DataType.u8, DataType.f16, DataType.f32, \ OpcodeClass.TensorOp, \ MathOperation.multiply_add_mixed_input_upcast), MathInstruction( \ [16, 8, 16], \ - DataType.u8, DataType.f16, DataType.f32, \ + DataType.s8, DataType.bf16, DataType.f32, \ OpcodeClass.TensorOp, \ MathOperation.multiply_add_mixed_input_upcast), MathInstruction( \ @@ -2595,7 +2595,12 @@ def GenerateSM80_TensorOp_16816_mixed_input_upcast_a(manifest, cuda_version): MathOperation.multiply_add_mixed_input_upcast), MathInstruction( \ [16, 8, 16], \ - DataType.s8, DataType.bf16, DataType.f32, \ + DataType.s8, DataType.f16, DataType.f16, \ + OpcodeClass.TensorOp, \ + MathOperation.multiply_add_mixed_input_upcast), + MathInstruction( \ + [16, 8, 16], \ + DataType.u8, DataType.f16, DataType.f16, \ OpcodeClass.TensorOp, \ MathOperation.multiply_add_mixed_input_upcast), ] @@ -2637,7 +2642,7 @@ def GenerateSM80_TensorOp_16816_mixed_input_upcast_a(manifest, cuda_version): data_type, alignment_constraints, None, EpilogueFunctor.LinearCombination, SwizzlingFunctor.Identity8) # Avoid emitting two kernels if the accumulator type does not differ from the input type (e.g. 
F16 accumulation) - if math_inst.element_a != math_inst.element_accumulator: + if math_inst.element_b != math_inst.element_accumulator: data_type_mixed = [ math_inst.element_a, @@ -2649,10 +2654,10 @@ def GenerateSM80_TensorOp_16816_mixed_input_upcast_a(manifest, cuda_version): operations += CreateGemmOperator(manifest, layouts, tile_descriptions, \ data_type_mixed, alignment_constraints, None, EpilogueFunctor.LinearCombination, SwizzlingFunctor.Identity8) - for op in operations: - if (DataTypeSize[op.C.element] == 16) and \ - (op.tile_description.threadblock_shape[1] <= 32): - op.C.alignment = 4 + for op in operations: + if (DataTypeSize[op.C.element] == 16) and \ + (op.tile_description.threadblock_shape[1] <= 32): + op.C.alignment = 4 # def GenerateSM80_TensorOp_16816_mixed_input_upcast_b(manifest, cuda_version): @@ -2672,12 +2677,12 @@ def GenerateSM80_TensorOp_16816_mixed_input_upcast_b(manifest, cuda_version): MathOperation.multiply_add_mixed_input_upcast), MathInstruction( \ [16, 8, 16], \ - DataType.bf16, DataType.s8, DataType.f32, \ + DataType.f16, DataType.u8, DataType.f32, \ OpcodeClass.TensorOp, \ MathOperation.multiply_add_mixed_input_upcast), MathInstruction( \ [16, 8, 16], \ - DataType.f16, DataType.u8, DataType.f32, \ + DataType.bf16, DataType.s8, DataType.f32, \ OpcodeClass.TensorOp, \ MathOperation.multiply_add_mixed_input_upcast), MathInstruction( \ @@ -2685,6 +2690,16 @@ def GenerateSM80_TensorOp_16816_mixed_input_upcast_b(manifest, cuda_version): DataType.bf16, DataType.u8, DataType.f32, \ OpcodeClass.TensorOp, \ MathOperation.multiply_add_mixed_input_upcast), + MathInstruction( \ + [16, 8, 16], \ + DataType.f16, DataType.s8, DataType.f16, \ + OpcodeClass.TensorOp, \ + MathOperation.multiply_add_mixed_input_upcast), + MathInstruction( \ + [16, 8, 16], \ + DataType.f16, DataType.u8, DataType.f16, \ + OpcodeClass.TensorOp, \ + MathOperation.multiply_add_mixed_input_upcast), ] min_cc = 80 @@ -2728,7 +2743,7 @@ def GenerateSM80_TensorOp_16816_mixed_input_upcast_b(manifest, cuda_version): ] # streamk uses more regs which can cause spill for the biggest warp tile size when the accumulators are 32bit. - CreateGemmOperator(manifest, layouts, tile_descriptions, \ + operations = CreateGemmOperator(manifest, layouts, tile_descriptions, \ data_type, alignment_constraints, None, EpilogueFunctor.LinearCombination, SwizzlingFunctor.Identity8) # Avoid emitting two kernels if the accumulator type does not differ from the input type (e.g. 
F16 accumulation) @@ -2741,12 +2756,12 @@ def GenerateSM80_TensorOp_16816_mixed_input_upcast_b(manifest, cuda_version): math_inst.element_accumulator, ] - operations = CreateGemmOperator(manifest, layouts, tile_descriptions, \ + operations += CreateGemmOperator(manifest, layouts, tile_descriptions, \ data_type_mixed, alignment_constraints, None, EpilogueFunctor.LinearCombination, SwizzlingFunctor.Identity8) - for op in operations: - if op.tile_description.threadblock_shape[1] <= 32: - op.C.alignment = 4 + for op in operations: + if op.tile_description.threadblock_shape[1] <= 32: + op.C.alignment = 4 # def GenerateSM80_TensorOp_16832_TN(manifest, cuda_version): diff --git a/test/unit/gemm/device/CMakeLists.txt b/test/unit/gemm/device/CMakeLists.txt index ce7b02606d..b5afa433e9 100644 --- a/test/unit/gemm/device/CMakeLists.txt +++ b/test/unit/gemm/device/CMakeLists.txt @@ -251,14 +251,32 @@ cutlass_test_unit_add_executable( BATCH_SIZE 4 # Upcast on Operand A - gemm_universal_u8t_f16n_f16t_mixed_input_tensor_op_f16_sm80.cu - gemm_universal_s8t_f16n_f16t_mixed_input_tensor_op_f16_sm80.cu + gemm_universal_s8t_f16n_f32t_mixed_input_tensor_op_f32_sm80.cu + gemm_universal_u8t_f16n_f32t_mixed_input_tensor_op_f32_sm80.cu + gemm_universal_s8t_f16n_f16t_mixed_input_tensor_op_f32_sm80.cu + gemm_universal_u8t_f16n_f16t_mixed_input_tensor_op_f32_sm80.cu + + gemm_universal_s8t_bf16n_f32t_mixed_input_tensor_op_f32_sm80.cu + gemm_universal_u8t_bf16n_f32t_mixed_input_tensor_op_f32_sm80.cu gemm_universal_s8t_bf16n_bf16t_mixed_input_tensor_op_f32_sm80.cu + gemm_universal_u8t_bf16n_bf16t_mixed_input_tensor_op_f32_sm80.cu + + gemm_universal_s8t_f16n_f16t_mixed_input_tensor_op_f16_sm80.cu + gemm_universal_u8t_f16n_f16t_mixed_input_tensor_op_f16_sm80.cu # Upcast on Operand B - gemm_universal_f16t_u8n_f16t_mixed_input_tensor_op_f16_sm80.cu - gemm_universal_f16t_s8n_f16t_mixed_input_tensor_op_f16_sm80.cu + gemm_universal_f16t_s8n_f32t_mixed_input_tensor_op_f32_sm80.cu + gemm_universal_f16t_u8n_f32t_mixed_input_tensor_op_f32_sm80.cu + gemm_universal_f16t_s8n_f16t_mixed_input_tensor_op_f32_sm80.cu + gemm_universal_f16t_u8n_f16t_mixed_input_tensor_op_f32_sm80.cu + + gemm_universal_bf16t_s8n_f32t_mixed_input_tensor_op_f32_sm80.cu + gemm_universal_bf16t_u8n_f32t_mixed_input_tensor_op_f32_sm80.cu gemm_universal_bf16t_s8n_bf16t_mixed_input_tensor_op_f32_sm80.cu + gemm_universal_bf16t_u8n_bf16t_mixed_input_tensor_op_f32_sm80.cu + + gemm_universal_f16t_s8n_f16t_mixed_input_tensor_op_f16_sm80.cu + gemm_universal_f16t_u8n_f16t_mixed_input_tensor_op_f16_sm80.cu ) cutlass_test_unit_add_executable( diff --git a/test/unit/gemm/device/gemm_universal_bf16t_s8n_f32t_mixed_input_tensor_op_f32_sm80.cu b/test/unit/gemm/device/gemm_universal_bf16t_s8n_f32t_mixed_input_tensor_op_f32_sm80.cu new file mode 100644 index 0000000000..bbeb9a1610 --- /dev/null +++ b/test/unit/gemm/device/gemm_universal_bf16t_s8n_f32t_mixed_input_tensor_op_f32_sm80.cu @@ -0,0 +1,97 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. 
Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Tests for device-wide GEMM interface + +*/ + +#include + +#include "../../common/cutlass_unit_test.h" +#include "cutlass/cutlass.h" + +#include "cutlass/gemm/device/gemm_universal.h" + +#include "cutlass/util/host_tensor.h" +#include "cutlass/util/reference/host/gemm.h" +#include "cutlass/util/reference/host/tensor_compare.h" +#include "cutlass/util/reference/host/tensor_copy.h" +#include "cutlass/util/reference/host/tensor_fill.h" +#include "cutlass/util/tensor_view_io.h" + +#include "testbed_universal.h" + +//////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) + +//////////////////////////////////////////////////////////////////////////////// + + +TEST(SM80_Device_GemmUniversal_bf16t_s8n_f32t_mixed_input_tensor_op_f32, 128x128x64_64x64x64) { + + using ElementA = cutlass::bfloat16_t; + using ElementB = int8_t; + using ElementOutput = float; + using ElementAccumulator = float; + + using Gemm = cutlass::gemm::device::GemmUniversal< + ElementA, + cutlass::layout::RowMajor, + ElementB, + cutlass::layout::ColumnMajor, + ElementOutput, + cutlass::layout::RowMajor, + ElementAccumulator, + cutlass::arch::OpClassTensorOp, + cutlass::arch::Sm80, + cutlass::gemm::GemmShape<128, 128, 64>, + cutlass::gemm::GemmShape<64, 64, 64>, + cutlass::gemm::GemmShape<16, 8, 16>, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, 128 / cutlass::sizeof_bits::value, + ElementAccumulator, ElementAccumulator>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 4, // Stages + 8, // AlignmentA + 16, // AlignmentB + cutlass::arch::OpMultiplyAddMixedInputUpcast, + cutlass::ComplexTransform::kNone, + cutlass::ComplexTransform::kNone + >; + + EXPECT_TRUE(test::gemm::device::TestAllGemmUniversal()); +} +//////////////////////////////////////////////////////////////////////////////// + +#endif // #if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) + +//////////////////////////////////////////////////////////////////////////////// diff --git a/test/unit/gemm/device/gemm_universal_bf16t_u8n_bf16t_mixed_input_tensor_op_f32_sm80.cu 
b/test/unit/gemm/device/gemm_universal_bf16t_u8n_bf16t_mixed_input_tensor_op_f32_sm80.cu new file mode 100644 index 0000000000..614b3e5f6c --- /dev/null +++ b/test/unit/gemm/device/gemm_universal_bf16t_u8n_bf16t_mixed_input_tensor_op_f32_sm80.cu @@ -0,0 +1,97 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! 
\file + \brief Tests for device-wide GEMM interface + +*/ + +#include + +#include "../../common/cutlass_unit_test.h" +#include "cutlass/cutlass.h" + +#include "cutlass/gemm/device/gemm_universal.h" + +#include "cutlass/util/host_tensor.h" +#include "cutlass/util/reference/host/gemm.h" +#include "cutlass/util/reference/host/tensor_compare.h" +#include "cutlass/util/reference/host/tensor_copy.h" +#include "cutlass/util/reference/host/tensor_fill.h" +#include "cutlass/util/tensor_view_io.h" + +#include "testbed_universal.h" + +//////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) + +//////////////////////////////////////////////////////////////////////////////// + + +TEST(SM80_Device_GemmUniversal_bf16t_u8n_bf16t_mixed_input_tensor_op_f32, 128x128x64_64x64x64) { + + using ElementA = cutlass::bfloat16_t; + using ElementB = uint8_t; + using ElementOutput = cutlass::bfloat16_t; + using ElementAccumulator = float; + + using Gemm = cutlass::gemm::device::GemmUniversal< + ElementA, + cutlass::layout::RowMajor, + ElementB, + cutlass::layout::ColumnMajor, + ElementOutput, + cutlass::layout::RowMajor, + ElementAccumulator, + cutlass::arch::OpClassTensorOp, + cutlass::arch::Sm80, + cutlass::gemm::GemmShape<128, 128, 64>, + cutlass::gemm::GemmShape<64, 64, 64>, + cutlass::gemm::GemmShape<16, 8, 16>, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, 128 / cutlass::sizeof_bits::value, + ElementAccumulator, ElementAccumulator>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 4, // Stages + 8, // AlignmentA + 16, // AlignmentB + cutlass::arch::OpMultiplyAddMixedInputUpcast, + cutlass::ComplexTransform::kNone, + cutlass::ComplexTransform::kNone + >; + + EXPECT_TRUE(test::gemm::device::TestAllGemmUniversal()); +} +//////////////////////////////////////////////////////////////////////////////// + +#endif // #if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) + +//////////////////////////////////////////////////////////////////////////////// diff --git a/test/unit/gemm/device/gemm_universal_bf16t_u8n_f32t_mixed_input_tensor_op_f32_sm80.cu b/test/unit/gemm/device/gemm_universal_bf16t_u8n_f32t_mixed_input_tensor_op_f32_sm80.cu new file mode 100644 index 0000000000..93c59c5178 --- /dev/null +++ b/test/unit/gemm/device/gemm_universal_bf16t_u8n_f32t_mixed_input_tensor_op_f32_sm80.cu @@ -0,0 +1,97 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. 
+ * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Tests for device-wide GEMM interface + +*/ + +#include + +#include "../../common/cutlass_unit_test.h" +#include "cutlass/cutlass.h" + +#include "cutlass/gemm/device/gemm_universal.h" + +#include "cutlass/util/host_tensor.h" +#include "cutlass/util/reference/host/gemm.h" +#include "cutlass/util/reference/host/tensor_compare.h" +#include "cutlass/util/reference/host/tensor_copy.h" +#include "cutlass/util/reference/host/tensor_fill.h" +#include "cutlass/util/tensor_view_io.h" + +#include "testbed_universal.h" + +//////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) + +//////////////////////////////////////////////////////////////////////////////// + + +TEST(SM80_Device_GemmUniversal_bf16t_u8n_f32t_mixed_input_tensor_op_f32, 128x128x64_64x64x64) { + + using ElementA = cutlass::bfloat16_t; + using ElementB = uint8_t; + using ElementOutput = float; + using ElementAccumulator = float; + + using Gemm = cutlass::gemm::device::GemmUniversal< + ElementA, + cutlass::layout::RowMajor, + ElementB, + cutlass::layout::ColumnMajor, + ElementOutput, + cutlass::layout::RowMajor, + ElementAccumulator, + cutlass::arch::OpClassTensorOp, + cutlass::arch::Sm80, + cutlass::gemm::GemmShape<128, 128, 64>, + cutlass::gemm::GemmShape<64, 64, 64>, + cutlass::gemm::GemmShape<16, 8, 16>, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, 128 / cutlass::sizeof_bits::value, + ElementAccumulator, ElementAccumulator>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 4, // Stages + 8, // AlignmentA + 16, // AlignmentB + cutlass::arch::OpMultiplyAddMixedInputUpcast, + cutlass::ComplexTransform::kNone, + cutlass::ComplexTransform::kNone + >; + + EXPECT_TRUE(test::gemm::device::TestAllGemmUniversal()); +} +//////////////////////////////////////////////////////////////////////////////// + +#endif // #if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) + +//////////////////////////////////////////////////////////////////////////////// diff --git a/test/unit/gemm/device/gemm_universal_f16t_s8n_f16t_mixed_input_tensor_op_f32_sm80.cu b/test/unit/gemm/device/gemm_universal_f16t_s8n_f16t_mixed_input_tensor_op_f32_sm80.cu new file mode 100644 index 0000000000..86d1da774e --- /dev/null +++ b/test/unit/gemm/device/gemm_universal_f16t_s8n_f16t_mixed_input_tensor_op_f32_sm80.cu @@ -0,0 +1,97 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+ * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Tests for device-wide GEMM interface + +*/ + +#include + +#include "../../common/cutlass_unit_test.h" +#include "cutlass/cutlass.h" + +#include "cutlass/gemm/device/gemm_universal.h" + +#include "cutlass/util/host_tensor.h" +#include "cutlass/util/reference/host/gemm.h" +#include "cutlass/util/reference/host/tensor_compare.h" +#include "cutlass/util/reference/host/tensor_copy.h" +#include "cutlass/util/reference/host/tensor_fill.h" +#include "cutlass/util/tensor_view_io.h" + +#include "testbed_universal.h" + +//////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) + +//////////////////////////////////////////////////////////////////////////////// + + +TEST(SM80_Device_GemmUniversal_f16t_s8n_f16t_mixed_input_tensor_op_f32, 128x128x64_64x64x64) { + + using ElementA = cutlass::half_t; + using ElementB = int8_t; + using ElementOutput = cutlass::half_t; + using ElementAccumulator = float; + + using Gemm = cutlass::gemm::device::GemmUniversal< + ElementA, + cutlass::layout::RowMajor, + ElementB, + cutlass::layout::ColumnMajor, + ElementOutput, + cutlass::layout::RowMajor, + ElementAccumulator, + cutlass::arch::OpClassTensorOp, + cutlass::arch::Sm80, + cutlass::gemm::GemmShape<128, 128, 64>, + cutlass::gemm::GemmShape<64, 64, 64>, + cutlass::gemm::GemmShape<16, 8, 16>, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, 128 / cutlass::sizeof_bits::value, + ElementAccumulator, ElementAccumulator>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 4, // Stages + 8, // AlignmentA + 16, // AlignmentB + cutlass::arch::OpMultiplyAddMixedInputUpcast, + cutlass::ComplexTransform::kNone, + cutlass::ComplexTransform::kNone + >; + + EXPECT_TRUE(test::gemm::device::TestAllGemmUniversal()); +} 
+//////////////////////////////////////////////////////////////////////////////// + +#endif // #if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) + +//////////////////////////////////////////////////////////////////////////////// diff --git a/test/unit/gemm/device/gemm_universal_f16t_s8n_f32t_mixed_input_tensor_op_f32_sm80.cu b/test/unit/gemm/device/gemm_universal_f16t_s8n_f32t_mixed_input_tensor_op_f32_sm80.cu new file mode 100644 index 0000000000..20da1150d0 --- /dev/null +++ b/test/unit/gemm/device/gemm_universal_f16t_s8n_f32t_mixed_input_tensor_op_f32_sm80.cu @@ -0,0 +1,97 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! 
\file + \brief Tests for device-wide GEMM interface + +*/ + +#include + +#include "../../common/cutlass_unit_test.h" +#include "cutlass/cutlass.h" + +#include "cutlass/gemm/device/gemm_universal.h" + +#include "cutlass/util/host_tensor.h" +#include "cutlass/util/reference/host/gemm.h" +#include "cutlass/util/reference/host/tensor_compare.h" +#include "cutlass/util/reference/host/tensor_copy.h" +#include "cutlass/util/reference/host/tensor_fill.h" +#include "cutlass/util/tensor_view_io.h" + +#include "testbed_universal.h" + +//////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) + +//////////////////////////////////////////////////////////////////////////////// + + +TEST(SM80_Device_GemmUniversal_f16t_s8n_f32t_mixed_input_tensor_op_f32, 128x128x64_64x64x64) { + + using ElementA = cutlass::half_t; + using ElementB = int8_t; + using ElementOutput = float; + using ElementAccumulator = float; + + using Gemm = cutlass::gemm::device::GemmUniversal< + ElementA, + cutlass::layout::RowMajor, + ElementB, + cutlass::layout::ColumnMajor, + ElementOutput, + cutlass::layout::RowMajor, + ElementAccumulator, + cutlass::arch::OpClassTensorOp, + cutlass::arch::Sm80, + cutlass::gemm::GemmShape<128, 128, 64>, + cutlass::gemm::GemmShape<64, 64, 64>, + cutlass::gemm::GemmShape<16, 8, 16>, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, 128 / cutlass::sizeof_bits::value, + ElementAccumulator, ElementAccumulator>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 4, // Stages + 8, // AlignmentA + 16, // AlignmentB + cutlass::arch::OpMultiplyAddMixedInputUpcast, + cutlass::ComplexTransform::kNone, + cutlass::ComplexTransform::kNone + >; + + EXPECT_TRUE(test::gemm::device::TestAllGemmUniversal()); +} +//////////////////////////////////////////////////////////////////////////////// + +#endif // #if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) + +//////////////////////////////////////////////////////////////////////////////// diff --git a/test/unit/gemm/device/gemm_universal_f16t_u8n_f16t_mixed_input_tensor_op_f16_sm80.cu b/test/unit/gemm/device/gemm_universal_f16t_u8n_f16t_mixed_input_tensor_op_f16_sm80.cu index d53711cb44..9b105c9eeb 100644 --- a/test/unit/gemm/device/gemm_universal_f16t_u8n_f16t_mixed_input_tensor_op_f16_sm80.cu +++ b/test/unit/gemm/device/gemm_universal_f16t_u8n_f16t_mixed_input_tensor_op_f16_sm80.cu @@ -56,7 +56,7 @@ //////////////////////////////////////////////////////////////////////////////// -TEST(SM80_Device_GemmUniversal_f16t_u8t_f16t_mixed_input_tensor_op_f16, 128x128x64_64x64x64) { +TEST(SM80_Device_GemmUniversal_f16t_u8n_f16t_mixed_input_tensor_op_f16, 128x128x64_64x64x64) { using ElementA = cutlass::half_t; using ElementB = uint8_t; diff --git a/test/unit/gemm/device/gemm_universal_f16t_u8n_f16t_mixed_input_tensor_op_f32_sm80.cu b/test/unit/gemm/device/gemm_universal_f16t_u8n_f16t_mixed_input_tensor_op_f32_sm80.cu new file mode 100644 index 0000000000..b26b213638 --- /dev/null +++ b/test/unit/gemm/device/gemm_universal_f16t_u8n_f16t_mixed_input_tensor_op_f32_sm80.cu @@ -0,0 +1,97 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. 
Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Tests for device-wide GEMM interface + +*/ + +#include + +#include "../../common/cutlass_unit_test.h" +#include "cutlass/cutlass.h" + +#include "cutlass/gemm/device/gemm_universal.h" + +#include "cutlass/util/host_tensor.h" +#include "cutlass/util/reference/host/gemm.h" +#include "cutlass/util/reference/host/tensor_compare.h" +#include "cutlass/util/reference/host/tensor_copy.h" +#include "cutlass/util/reference/host/tensor_fill.h" +#include "cutlass/util/tensor_view_io.h" + +#include "testbed_universal.h" + +//////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) + +//////////////////////////////////////////////////////////////////////////////// + + +TEST(SM80_Device_GemmUniversal_f16t_u8n_f16t_mixed_input_tensor_op_f32, 128x128x64_64x64x64) { + + using ElementA = cutlass::half_t; + using ElementB = uint8_t; + using ElementOutput = cutlass::half_t; + using ElementAccumulator = float; + + using Gemm = cutlass::gemm::device::GemmUniversal< + ElementA, + cutlass::layout::RowMajor, + ElementB, + cutlass::layout::ColumnMajor, + ElementOutput, + cutlass::layout::RowMajor, + ElementAccumulator, + cutlass::arch::OpClassTensorOp, + cutlass::arch::Sm80, + cutlass::gemm::GemmShape<128, 128, 64>, + cutlass::gemm::GemmShape<64, 64, 64>, + cutlass::gemm::GemmShape<16, 8, 16>, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, 128 / cutlass::sizeof_bits::value, + ElementAccumulator, ElementAccumulator>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 4, // Stages + 8, // AlignmentA + 16, // AlignmentB + cutlass::arch::OpMultiplyAddMixedInputUpcast, + cutlass::ComplexTransform::kNone, + cutlass::ComplexTransform::kNone + >; + + EXPECT_TRUE(test::gemm::device::TestAllGemmUniversal()); +} +//////////////////////////////////////////////////////////////////////////////// + +#endif // #if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) + +//////////////////////////////////////////////////////////////////////////////// diff --git 
a/test/unit/gemm/device/gemm_universal_f16t_u8n_f32t_mixed_input_tensor_op_f32_sm80.cu b/test/unit/gemm/device/gemm_universal_f16t_u8n_f32t_mixed_input_tensor_op_f32_sm80.cu new file mode 100644 index 0000000000..926a88e8e4 --- /dev/null +++ b/test/unit/gemm/device/gemm_universal_f16t_u8n_f32t_mixed_input_tensor_op_f32_sm80.cu @@ -0,0 +1,97 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! 
\file + \brief Tests for device-wide GEMM interface + +*/ + +#include + +#include "../../common/cutlass_unit_test.h" +#include "cutlass/cutlass.h" + +#include "cutlass/gemm/device/gemm_universal.h" + +#include "cutlass/util/host_tensor.h" +#include "cutlass/util/reference/host/gemm.h" +#include "cutlass/util/reference/host/tensor_compare.h" +#include "cutlass/util/reference/host/tensor_copy.h" +#include "cutlass/util/reference/host/tensor_fill.h" +#include "cutlass/util/tensor_view_io.h" + +#include "testbed_universal.h" + +//////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) + +//////////////////////////////////////////////////////////////////////////////// + + +TEST(SM80_Device_GemmUniversal_f16t_u8n_f32t_mixed_input_tensor_op_f32, 128x128x64_64x64x64) { + + using ElementA = cutlass::half_t; + using ElementB = uint8_t; + using ElementOutput = float; + using ElementAccumulator = float; + + using Gemm = cutlass::gemm::device::GemmUniversal< + ElementA, + cutlass::layout::RowMajor, + ElementB, + cutlass::layout::ColumnMajor, + ElementOutput, + cutlass::layout::RowMajor, + ElementAccumulator, + cutlass::arch::OpClassTensorOp, + cutlass::arch::Sm80, + cutlass::gemm::GemmShape<128, 128, 64>, + cutlass::gemm::GemmShape<64, 64, 64>, + cutlass::gemm::GemmShape<16, 8, 16>, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, 128 / cutlass::sizeof_bits::value, + ElementAccumulator, ElementAccumulator>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 4, // Stages + 8, // AlignmentA + 16, // AlignmentB + cutlass::arch::OpMultiplyAddMixedInputUpcast, + cutlass::ComplexTransform::kNone, + cutlass::ComplexTransform::kNone + >; + + EXPECT_TRUE(test::gemm::device::TestAllGemmUniversal()); +} +//////////////////////////////////////////////////////////////////////////////// + +#endif // #if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) + +//////////////////////////////////////////////////////////////////////////////// diff --git a/test/unit/gemm/device/gemm_universal_s8t_bf16n_f32t_mixed_input_tensor_op_f32_sm80.cu b/test/unit/gemm/device/gemm_universal_s8t_bf16n_f32t_mixed_input_tensor_op_f32_sm80.cu new file mode 100644 index 0000000000..8572e55981 --- /dev/null +++ b/test/unit/gemm/device/gemm_universal_s8t_bf16n_f32t_mixed_input_tensor_op_f32_sm80.cu @@ -0,0 +1,97 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. 
+ * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Tests for device-wide GEMM interface + +*/ + +#include + +#include "../../common/cutlass_unit_test.h" +#include "cutlass/cutlass.h" + +#include "cutlass/gemm/device/gemm_universal.h" + +#include "cutlass/util/host_tensor.h" +#include "cutlass/util/reference/host/gemm.h" +#include "cutlass/util/reference/host/tensor_compare.h" +#include "cutlass/util/reference/host/tensor_copy.h" +#include "cutlass/util/reference/host/tensor_fill.h" +#include "cutlass/util/tensor_view_io.h" + +#include "testbed_universal.h" + +//////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) + +//////////////////////////////////////////////////////////////////////////////// + + +TEST(SM80_Device_GemmUniversal_s8t_bf16n_f32t_mixed_input_tensor_op_f32, 128x128x64_64x64x64) { + + using ElementA = int8_t; + using ElementB = cutlass::bfloat16_t; + using ElementOutput = float; + using ElementAccumulator = float; + + using Gemm = cutlass::gemm::device::GemmUniversal< + ElementA, + cutlass::layout::RowMajor, + ElementB, + cutlass::layout::ColumnMajor, + ElementOutput, + cutlass::layout::RowMajor, + ElementAccumulator, + cutlass::arch::OpClassTensorOp, + cutlass::arch::Sm80, + cutlass::gemm::GemmShape<128, 128, 64>, + cutlass::gemm::GemmShape<64, 64, 64>, + cutlass::gemm::GemmShape<16, 8, 16>, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, 128 / cutlass::sizeof_bits::value, + ElementAccumulator, ElementAccumulator>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 4, // Stages + 16, // AlignmentA + 8, // AlignmentB + cutlass::arch::OpMultiplyAddMixedInputUpcast, + cutlass::ComplexTransform::kNone, + cutlass::ComplexTransform::kNone + >; + + EXPECT_TRUE(test::gemm::device::TestAllGemmUniversal()); +} +//////////////////////////////////////////////////////////////////////////////// + +#endif // #if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) + +//////////////////////////////////////////////////////////////////////////////// diff --git a/test/unit/gemm/device/gemm_universal_s8t_f16n_f16t_mixed_input_tensor_op_f32_sm80.cu b/test/unit/gemm/device/gemm_universal_s8t_f16n_f16t_mixed_input_tensor_op_f32_sm80.cu new file mode 100644 index 0000000000..eb4e293a39 --- /dev/null +++ b/test/unit/gemm/device/gemm_universal_s8t_f16n_f16t_mixed_input_tensor_op_f32_sm80.cu @@ -0,0 +1,97 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+ * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Tests for device-wide GEMM interface + +*/ + +#include + +#include "../../common/cutlass_unit_test.h" +#include "cutlass/cutlass.h" + +#include "cutlass/gemm/device/gemm_universal.h" + +#include "cutlass/util/host_tensor.h" +#include "cutlass/util/reference/host/gemm.h" +#include "cutlass/util/reference/host/tensor_compare.h" +#include "cutlass/util/reference/host/tensor_copy.h" +#include "cutlass/util/reference/host/tensor_fill.h" +#include "cutlass/util/tensor_view_io.h" + +#include "testbed_universal.h" + +//////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) + +//////////////////////////////////////////////////////////////////////////////// + + +TEST(SM80_Device_GemmUniversal_s8t_f16n_f16t_mixed_input_tensor_op_f32, 128x128x64_64x64x64) { + + using ElementA = int8_t; + using ElementB = cutlass::half_t; + using ElementOutput = cutlass::half_t; + using ElementAccumulator = float; + + using Gemm = cutlass::gemm::device::GemmUniversal< + ElementA, + cutlass::layout::RowMajor, + ElementB, + cutlass::layout::ColumnMajor, + ElementOutput, + cutlass::layout::RowMajor, + ElementAccumulator, + cutlass::arch::OpClassTensorOp, + cutlass::arch::Sm80, + cutlass::gemm::GemmShape<128, 128, 64>, + cutlass::gemm::GemmShape<64, 64, 64>, + cutlass::gemm::GemmShape<16, 8, 16>, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, 128 / cutlass::sizeof_bits::value, + ElementAccumulator, ElementAccumulator>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 4, // Stages + 16, // AlignmentA + 8, // AlignmentB + cutlass::arch::OpMultiplyAddMixedInputUpcast, + cutlass::ComplexTransform::kNone, + cutlass::ComplexTransform::kNone + >; + + EXPECT_TRUE(test::gemm::device::TestAllGemmUniversal()); +} 
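As an aside on usage: the tests above drive these kernels through TestAllGemmUniversal, but it may help to see roughly how one of the mixed-input GemmUniversal instantiations would be launched by hand. The sketch below follows the usual device::GemmUniversal Arguments pattern from the CUTLASS examples; the problem size, device pointers, and leading dimensions are illustrative assumptions, and testbed_universal.h remains the authoritative usage.

  // Sketch only: assumes ptr_A/ptr_B/ptr_C/ptr_D are device allocations of the
  // element types above, sized for an M=256, N=256, K=128 problem.
  Gemm gemm_op;
  typename Gemm::Arguments args(
    cutlass::gemm::GemmUniversalMode::kGemm,
    {256, 256, 128},                                  // problem size (M, N, K)
    1,                                                // batch count
    {ElementAccumulator(1), ElementAccumulator(0)},   // alpha, beta
    ptr_A, ptr_B, ptr_C, ptr_D,                       // operand/output pointers
    0, 0, 0, 0,                                       // batch strides
    128, 128, 256, 256);                              // lda, ldb, ldc, ldd
  if (gemm_op.can_implement(args) == cutlass::Status::kSuccess) {
    gemm_op.initialize(args);  // no extra workspace assumed for this configuration
    gemm_op();                 // launch on the default stream
  }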
+//////////////////////////////////////////////////////////////////////////////// + +#endif // #if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) + +//////////////////////////////////////////////////////////////////////////////// diff --git a/test/unit/gemm/device/gemm_universal_s8t_f16n_f32t_mixed_input_tensor_op_f32_sm80.cu b/test/unit/gemm/device/gemm_universal_s8t_f16n_f32t_mixed_input_tensor_op_f32_sm80.cu new file mode 100644 index 0000000000..064c9b048d --- /dev/null +++ b/test/unit/gemm/device/gemm_universal_s8t_f16n_f32t_mixed_input_tensor_op_f32_sm80.cu @@ -0,0 +1,97 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! 
\file + \brief Tests for device-wide GEMM interface + +*/ + +#include + +#include "../../common/cutlass_unit_test.h" +#include "cutlass/cutlass.h" + +#include "cutlass/gemm/device/gemm_universal.h" + +#include "cutlass/util/host_tensor.h" +#include "cutlass/util/reference/host/gemm.h" +#include "cutlass/util/reference/host/tensor_compare.h" +#include "cutlass/util/reference/host/tensor_copy.h" +#include "cutlass/util/reference/host/tensor_fill.h" +#include "cutlass/util/tensor_view_io.h" + +#include "testbed_universal.h" + +//////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) + +//////////////////////////////////////////////////////////////////////////////// + + +TEST(SM80_Device_GemmUniversal_s8t_f16n_f32t_mixed_input_tensor_op_f32, 128x128x64_64x64x64) { + + using ElementA = int8_t; + using ElementB = cutlass::half_t; + using ElementOutput = float; + using ElementAccumulator = float; + + using Gemm = cutlass::gemm::device::GemmUniversal< + ElementA, + cutlass::layout::RowMajor, + ElementB, + cutlass::layout::ColumnMajor, + ElementOutput, + cutlass::layout::RowMajor, + ElementAccumulator, + cutlass::arch::OpClassTensorOp, + cutlass::arch::Sm80, + cutlass::gemm::GemmShape<128, 128, 64>, + cutlass::gemm::GemmShape<64, 64, 64>, + cutlass::gemm::GemmShape<16, 8, 16>, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, 128 / cutlass::sizeof_bits::value, + ElementAccumulator, ElementAccumulator>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 4, // Stages + 16, // AlignmentA + 8, // AlignmentB + cutlass::arch::OpMultiplyAddMixedInputUpcast, + cutlass::ComplexTransform::kNone, + cutlass::ComplexTransform::kNone + >; + + EXPECT_TRUE(test::gemm::device::TestAllGemmUniversal()); +} +//////////////////////////////////////////////////////////////////////////////// + +#endif // #if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) + +//////////////////////////////////////////////////////////////////////////////// diff --git a/test/unit/gemm/device/gemm_universal_u8t_bf16n_bf16t_mixed_input_tensor_op_f32_sm80.cu b/test/unit/gemm/device/gemm_universal_u8t_bf16n_bf16t_mixed_input_tensor_op_f32_sm80.cu new file mode 100644 index 0000000000..b2a29022c0 --- /dev/null +++ b/test/unit/gemm/device/gemm_universal_u8t_bf16n_bf16t_mixed_input_tensor_op_f32_sm80.cu @@ -0,0 +1,97 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. 
+ * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Tests for device-wide GEMM interface + +*/ + +#include + +#include "../../common/cutlass_unit_test.h" +#include "cutlass/cutlass.h" + +#include "cutlass/gemm/device/gemm_universal.h" + +#include "cutlass/util/host_tensor.h" +#include "cutlass/util/reference/host/gemm.h" +#include "cutlass/util/reference/host/tensor_compare.h" +#include "cutlass/util/reference/host/tensor_copy.h" +#include "cutlass/util/reference/host/tensor_fill.h" +#include "cutlass/util/tensor_view_io.h" + +#include "testbed_universal.h" + +//////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) + +//////////////////////////////////////////////////////////////////////////////// + + +TEST(SM80_Device_GemmUniversal_u8t_bf16n_bf16t_mixed_input_tensor_op_f32, 128x128x64_64x64x64) { + + using ElementA = uint8_t; + using ElementB = cutlass::bfloat16_t; + using ElementOutput = cutlass::bfloat16_t; + using ElementAccumulator = float; + + using Gemm = cutlass::gemm::device::GemmUniversal< + ElementA, + cutlass::layout::RowMajor, + ElementB, + cutlass::layout::ColumnMajor, + ElementOutput, + cutlass::layout::RowMajor, + ElementAccumulator, + cutlass::arch::OpClassTensorOp, + cutlass::arch::Sm80, + cutlass::gemm::GemmShape<128, 128, 64>, + cutlass::gemm::GemmShape<64, 64, 64>, + cutlass::gemm::GemmShape<16, 8, 16>, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, 128 / cutlass::sizeof_bits::value, + ElementAccumulator, ElementAccumulator>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 4, // Stages + 16, // AlignmentA + 8, // AlignmentB + cutlass::arch::OpMultiplyAddMixedInputUpcast, + cutlass::ComplexTransform::kNone, + cutlass::ComplexTransform::kNone + >; + + EXPECT_TRUE(test::gemm::device::TestAllGemmUniversal()); +} +//////////////////////////////////////////////////////////////////////////////// + +#endif // #if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) + +//////////////////////////////////////////////////////////////////////////////// diff --git a/test/unit/gemm/device/gemm_universal_u8t_bf16n_f32t_mixed_input_tensor_op_f32_sm80.cu b/test/unit/gemm/device/gemm_universal_u8t_bf16n_f32t_mixed_input_tensor_op_f32_sm80.cu new file mode 100644 index 0000000000..d6b65974c8 --- /dev/null +++ b/test/unit/gemm/device/gemm_universal_u8t_bf16n_f32t_mixed_input_tensor_op_f32_sm80.cu @@ -0,0 +1,97 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+ * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Tests for device-wide GEMM interface + +*/ + +#include + +#include "../../common/cutlass_unit_test.h" +#include "cutlass/cutlass.h" + +#include "cutlass/gemm/device/gemm_universal.h" + +#include "cutlass/util/host_tensor.h" +#include "cutlass/util/reference/host/gemm.h" +#include "cutlass/util/reference/host/tensor_compare.h" +#include "cutlass/util/reference/host/tensor_copy.h" +#include "cutlass/util/reference/host/tensor_fill.h" +#include "cutlass/util/tensor_view_io.h" + +#include "testbed_universal.h" + +//////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) + +//////////////////////////////////////////////////////////////////////////////// + + +TEST(SM80_Device_GemmUniversal_u8t_bf16n_f32t_mixed_input_tensor_op_f32, 128x128x64_64x64x64) { + + using ElementA = uint8_t; + using ElementB = cutlass::bfloat16_t; + using ElementOutput = float; + using ElementAccumulator = float; + + using Gemm = cutlass::gemm::device::GemmUniversal< + ElementA, + cutlass::layout::RowMajor, + ElementB, + cutlass::layout::ColumnMajor, + ElementOutput, + cutlass::layout::RowMajor, + ElementAccumulator, + cutlass::arch::OpClassTensorOp, + cutlass::arch::Sm80, + cutlass::gemm::GemmShape<128, 128, 64>, + cutlass::gemm::GemmShape<64, 64, 64>, + cutlass::gemm::GemmShape<16, 8, 16>, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, 128 / cutlass::sizeof_bits::value, + ElementAccumulator, ElementAccumulator>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 4, // Stages + 16, // AlignmentA + 8, // AlignmentB + cutlass::arch::OpMultiplyAddMixedInputUpcast, + cutlass::ComplexTransform::kNone, + cutlass::ComplexTransform::kNone + >; + + EXPECT_TRUE(test::gemm::device::TestAllGemmUniversal()); +} 
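A short grounding note on the alignment template arguments in these tests (16 for the 8-bit operand, 8 for the 16-bit one): they match 128-bit global memory accesses per operand, the same 128 / sizeof_bits rule the epilogue alignment above is spelled with. A compile-time sanity check, as a sketch:

  static_assert(128 / cutlass::sizeof_bits<uint8_t>::value == 16,
                "AlignmentA of 16 for an 8-bit operand is a 128-bit access");
  static_assert(128 / cutlass::sizeof_bits<cutlass::bfloat16_t>::value == 8,
                "AlignmentB of 8 for a 16-bit operand is a 128-bit access");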
+//////////////////////////////////////////////////////////////////////////////// + +#endif // #if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) + +//////////////////////////////////////////////////////////////////////////////// diff --git a/test/unit/gemm/device/gemm_universal_u8t_f16n_f16t_mixed_input_tensor_op_f16_sm80.cu b/test/unit/gemm/device/gemm_universal_u8t_f16n_f16t_mixed_input_tensor_op_f16_sm80.cu index 20a6b5a58c..41657c2fca 100644 --- a/test/unit/gemm/device/gemm_universal_u8t_f16n_f16t_mixed_input_tensor_op_f16_sm80.cu +++ b/test/unit/gemm/device/gemm_universal_u8t_f16n_f16t_mixed_input_tensor_op_f16_sm80.cu @@ -56,7 +56,7 @@ //////////////////////////////////////////////////////////////////////////////// -TEST(SM80_Device_GemmUniversal_u8t_f16t_f16t_mixed_input_tensor_op_f16, 128x128x64_64x64x64) { +TEST(SM80_Device_GemmUniversal_u8t_f16n_f16t_mixed_input_tensor_op_f16, 128x128x64_64x64x64) { using ElementA = uint8_t; using ElementB = cutlass::half_t; diff --git a/test/unit/gemm/device/gemm_universal_u8t_f16n_f16t_mixed_input_tensor_op_f32_sm80.cu b/test/unit/gemm/device/gemm_universal_u8t_f16n_f16t_mixed_input_tensor_op_f32_sm80.cu new file mode 100644 index 0000000000..b2b3cd3a21 --- /dev/null +++ b/test/unit/gemm/device/gemm_universal_u8t_f16n_f16t_mixed_input_tensor_op_f32_sm80.cu @@ -0,0 +1,97 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! 
\file + \brief Tests for device-wide GEMM interface + +*/ + +#include + +#include "../../common/cutlass_unit_test.h" +#include "cutlass/cutlass.h" + +#include "cutlass/gemm/device/gemm_universal.h" + +#include "cutlass/util/host_tensor.h" +#include "cutlass/util/reference/host/gemm.h" +#include "cutlass/util/reference/host/tensor_compare.h" +#include "cutlass/util/reference/host/tensor_copy.h" +#include "cutlass/util/reference/host/tensor_fill.h" +#include "cutlass/util/tensor_view_io.h" + +#include "testbed_universal.h" + +//////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) + +//////////////////////////////////////////////////////////////////////////////// + + +TEST(SM80_Device_GemmUniversal_u8t_f16n_f16t_mixed_input_tensor_op_f32, 128x128x64_64x64x64) { + + using ElementA = uint8_t; + using ElementB = cutlass::half_t; + using ElementOutput = cutlass::half_t; + using ElementAccumulator = float; + + using Gemm = cutlass::gemm::device::GemmUniversal< + ElementA, + cutlass::layout::RowMajor, + ElementB, + cutlass::layout::ColumnMajor, + ElementOutput, + cutlass::layout::RowMajor, + ElementAccumulator, + cutlass::arch::OpClassTensorOp, + cutlass::arch::Sm80, + cutlass::gemm::GemmShape<128, 128, 64>, + cutlass::gemm::GemmShape<64, 64, 64>, + cutlass::gemm::GemmShape<16, 8, 16>, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, 128 / cutlass::sizeof_bits::value, + ElementAccumulator, ElementAccumulator>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 4, // Stages + 16, // AlignmentA + 8, // AlignmentB + cutlass::arch::OpMultiplyAddMixedInputUpcast, + cutlass::ComplexTransform::kNone, + cutlass::ComplexTransform::kNone + >; + + EXPECT_TRUE(test::gemm::device::TestAllGemmUniversal()); +} +//////////////////////////////////////////////////////////////////////////////// + +#endif // #if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) + +//////////////////////////////////////////////////////////////////////////////// diff --git a/test/unit/gemm/device/gemm_universal_u8t_f16n_f32t_mixed_input_tensor_op_f32_sm80.cu b/test/unit/gemm/device/gemm_universal_u8t_f16n_f32t_mixed_input_tensor_op_f32_sm80.cu new file mode 100644 index 0000000000..358c109e86 --- /dev/null +++ b/test/unit/gemm/device/gemm_universal_u8t_f16n_f32t_mixed_input_tensor_op_f32_sm80.cu @@ -0,0 +1,97 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. 
+ * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Tests for device-wide GEMM interface + +*/ + +#include + +#include "../../common/cutlass_unit_test.h" +#include "cutlass/cutlass.h" + +#include "cutlass/gemm/device/gemm_universal.h" + +#include "cutlass/util/host_tensor.h" +#include "cutlass/util/reference/host/gemm.h" +#include "cutlass/util/reference/host/tensor_compare.h" +#include "cutlass/util/reference/host/tensor_copy.h" +#include "cutlass/util/reference/host/tensor_fill.h" +#include "cutlass/util/tensor_view_io.h" + +#include "testbed_universal.h" + +//////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) + +//////////////////////////////////////////////////////////////////////////////// + + +TEST(SM80_Device_GemmUniversal_u8t_f16n_f32t_mixed_input_tensor_op_f32, 128x128x64_64x64x64) { + + using ElementA = uint8_t; + using ElementB = cutlass::half_t; + using ElementOutput = float; + using ElementAccumulator = float; + + using Gemm = cutlass::gemm::device::GemmUniversal< + ElementA, + cutlass::layout::RowMajor, + ElementB, + cutlass::layout::ColumnMajor, + ElementOutput, + cutlass::layout::RowMajor, + ElementAccumulator, + cutlass::arch::OpClassTensorOp, + cutlass::arch::Sm80, + cutlass::gemm::GemmShape<128, 128, 64>, + cutlass::gemm::GemmShape<64, 64, 64>, + cutlass::gemm::GemmShape<16, 8, 16>, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, 128 / cutlass::sizeof_bits::value, + ElementAccumulator, ElementAccumulator>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 4, // Stages + 16, // AlignmentA + 8, // AlignmentB + cutlass::arch::OpMultiplyAddMixedInputUpcast, + cutlass::ComplexTransform::kNone, + cutlass::ComplexTransform::kNone + >; + + EXPECT_TRUE(test::gemm::device::TestAllGemmUniversal()); +} +//////////////////////////////////////////////////////////////////////////////// + +#endif // #if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) + +//////////////////////////////////////////////////////////////////////////////// diff --git a/test/unit/gemm/device/testbed_universal.h b/test/unit/gemm/device/testbed_universal.h index 8dc92db0e5..a5f9e819b6 100644 --- a/test/unit/gemm/device/testbed_universal.h +++ b/test/unit/gemm/device/testbed_universal.h @@ -112,8 +112,13 @@ struct TestbedUniversal { scope_max = is_unsigned_int ? 2 : 1; scope_min = is_unsigned_int ? 0 : -1; } else if (bits_output == 16) { - scope_max = is_unsigned_int ? 10 : 5; - scope_min = is_unsigned_int ? 
0 : -5; + constexpr auto u8_bf16 = + (cutlass::platform::is_same::value && + cutlass::platform::is_same::value) || + (cutlass::platform::is_same::value && + cutlass::platform::is_same::value); + scope_max = is_unsigned_int ? 10 : (u8_bf16 ? 3 : 5); + scope_min = is_unsigned_int ? 0 : (u8_bf16 ? -3 : -5); } else { scope_max = 8; scope_min = -8; diff --git a/tools/library/src/reference/gemm_fp_mixed_input.cu b/tools/library/src/reference/gemm_fp_mixed_input.cu index 7937552cd9..ef4913d08c 100644 --- a/tools/library/src/reference/gemm_fp_mixed_input.cu +++ b/tools/library/src/reference/gemm_fp_mixed_input.cu @@ -50,15 +50,34 @@ void initialize_gemm_reference_operations_fp_mixed_input(Manifest &manifest) { make_gemm_real_canonical_layouts< int8_t, half_t, + float, + float + >(manifest); + + make_gemm_real_canonical_layouts< + uint8_t, half_t, + float, + float + >(manifest); + + make_gemm_real_canonical_layouts< + int8_t, half_t, - half_t + half_t, + float >(manifest); make_gemm_real_canonical_layouts< uint8_t, half_t, half_t, + float + >(manifest); + + make_gemm_real_canonical_layouts< + int8_t, + half_t, half_t, half_t >(manifest); @@ -67,47 +86,62 @@ void initialize_gemm_reference_operations_fp_mixed_input(Manifest &manifest) { uint8_t, half_t, half_t, + half_t + >(manifest); + + make_gemm_real_canonical_layouts< + half_t, + int8_t, float, float >(manifest); make_gemm_real_canonical_layouts< - int8_t, - half_t, half_t, + uint8_t, float, float >(manifest); + make_gemm_real_canonical_layouts< + half_t, + int8_t, + half_t, + half_t + >(manifest); + make_gemm_real_canonical_layouts< half_t, uint8_t, half_t, - float, - float + half_t >(manifest); make_gemm_real_canonical_layouts< half_t, int8_t, half_t, - float, float >(manifest); - // bfloat16_t mixed with 8-bit integer input make_gemm_real_canonical_layouts< + half_t, uint8_t, - bfloat16_t, - bfloat16_t, - float, + half_t, float >(manifest); + // bfloat16_t mixed with 8-bit integer input make_gemm_real_canonical_layouts< int8_t, bfloat16_t, float, + float + >(manifest); + + make_gemm_real_canonical_layouts< + uint8_t, + bfloat16_t, float, float >(manifest); @@ -116,14 +150,19 @@ void initialize_gemm_reference_operations_fp_mixed_input(Manifest &manifest) { int8_t, bfloat16_t, bfloat16_t, - float, float >(manifest); make_gemm_real_canonical_layouts< - bfloat16_t, uint8_t, - float, + bfloat16_t, + bfloat16_t, + float + >(manifest); + + make_gemm_real_canonical_layouts< + bfloat16_t, + int8_t, float, float >(manifest); @@ -131,7 +170,6 @@ void initialize_gemm_reference_operations_fp_mixed_input(Manifest &manifest) { make_gemm_real_canonical_layouts< bfloat16_t, uint8_t, - bfloat16_t, float, float >(manifest); @@ -140,7 +178,13 @@ void initialize_gemm_reference_operations_fp_mixed_input(Manifest &manifest) { bfloat16_t, int8_t, bfloat16_t, - float, + float + >(manifest); + + make_gemm_real_canonical_layouts< + bfloat16_t, + uint8_t, + bfloat16_t, float >(manifest); } From 4dbf5dbed2331b948b75a3dbeaf760d76b3b5964 Mon Sep 17 00:00:00 2001 From: shunfan-shao <79347016+shunfan-shao@users.noreply.github.com> Date: Mon, 19 Aug 2024 10:26:09 -0700 Subject: [PATCH 10/53] Use CUDA runtime API to retrieve function pointer to driver API (#1700) * Query pfn to driver api * use default for older toolkits --------- Co-authored-by: shunfans --- CMakeLists.txt | 1 + include/cute/atom/copy_traits_sm90_im2col.hpp | 4 +- include/cute/atom/copy_traits_sm90_tma.hpp | 3 +- include/cutlass/cuda_host_adapter.hpp | 72 +++++++++++++++++++ 4 files changed, 78 insertions(+), 
2 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index ac67eb8653..4e1ffd7531 100755 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -234,6 +234,7 @@ set(CUTLASS_NVCC_EMBED_CUBIN ON CACHE BOOL "Embed compiled CUDA kernel binaries set(CUTLASS_NVCC_EMBED_PTX ON CACHE BOOL "Embed compiled PTX into executables.") set(CUTLASS_NVCC_KEEP OFF CACHE BOOL "Keep intermediate files generated by NVCC.") set(CUTLASS_ENABLE_F16C OFF CACHE BOOL "Enable F16C x86 extensions in host code.") +set(CUTLASS_ENABLE_DIRECT_CUDA_DRIVER_CALL OFF CACHE BOOL "Enable CUTLASS to directly call driver API) ################################################################################ # diff --git a/include/cute/atom/copy_traits_sm90_im2col.hpp b/include/cute/atom/copy_traits_sm90_im2col.hpp index f6c9e258eb..ad4f8675b5 100644 --- a/include/cute/atom/copy_traits_sm90_im2col.hpp +++ b/include/cute/atom/copy_traits_sm90_im2col.hpp @@ -40,6 +40,8 @@ #include "cute/algorithm/prefetch.hpp" #include "cutlass/fast_math.h" +#include "cutlass/cuda_host_adapter.hpp" + namespace cute { @@ -450,7 +452,7 @@ make_im2col_tma_copy_desc( CUtensorMapFloatOOBfill tma_oob_fill = to_CUtensorMapFloatOOBfill(aux_params.oobfill_); CUtensorMapSwizzle tma_swizzle = TMA::to_CUtensorMapSwizzle(detail::get_tma_swizzle_bits(smem_swizzle)); - CUresult encode_result = cuTensorMapEncodeIm2col( + CUresult encode_result = CUTLASS_CUDA_DRIVER_WRAPPER_CALL(cuTensorMapEncodeIm2col)( &tma_desc, tma_format, num_total_modes, diff --git a/include/cute/atom/copy_traits_sm90_tma.hpp b/include/cute/atom/copy_traits_sm90_tma.hpp index 950855a1a2..2238c41897 100644 --- a/include/cute/atom/copy_traits_sm90_tma.hpp +++ b/include/cute/atom/copy_traits_sm90_tma.hpp @@ -41,6 +41,7 @@ #include #include +#include namespace cute { @@ -983,7 +984,7 @@ make_tma_copy_desc(Tensor const& gtensor, // The origin // TMA smem swizzle type CUtensorMapSwizzle smem_swizzle = TMA::to_CUtensorMapSwizzle(get_tma_swizzle_bits(swizzle)); - CUresult result = cuTensorMapEncodeTiled( + CUresult result = CUTLASS_CUDA_DRIVER_WRAPPER_CALL(cuTensorMapEncodeTiled)( &tma_desc, tma_format, tma_dim, diff --git a/include/cutlass/cuda_host_adapter.hpp b/include/cutlass/cuda_host_adapter.hpp index 28f5ae0e8d..f9ff723ce1 100644 --- a/include/cutlass/cuda_host_adapter.hpp +++ b/include/cutlass/cuda_host_adapter.hpp @@ -82,6 +82,78 @@ namespace cutlass { ///////////////////////////////////////////////////////////////////////////////////////////////// +#if !defined(__CUDACC_RTC__) + +#include +#include + +#define CUTLASS_CUDA_DRIVER_STRINGIFY(tok) #tok + +#if defined(CUTLASS_ENABLE_DIRECT_CUDA_DRIVER_CALL) + +#define CUTLASS_CUDA_DRIVER_WRAPPER_DECL(func, ver) \ + template \ + CUresult call_##func(Args... args) { \ + return func(args...); \ + } + +#else // defined(CUTLASS_ENABLE_DIRECT_CUDA_DRIVER_CALL) + +#if (__CUDACC_VER_MAJOR__ >= 12 && __CUDACC_VER_MINOR__ >= 5) + +#define CUTLASS_CUDA_DRIVER_WRAPPER_DECL(func, ver) \ + template \ + CUresult call_##func(Args... args) { \ + cudaDriverEntryPointQueryResult cuda_status; \ + void* pfn = nullptr; \ + cudaError_t cuda_err = cudaGetDriverEntryPointByVersion( \ + CUTLASS_CUDA_DRIVER_STRINGIFY(func), \ + &pfn, ver, \ + cudaEnableDefault, \ + &cuda_status); \ + if (cuda_status != cudaDriverEntryPointSuccess || \ + cuda_err != cudaSuccess) { \ + return CUDA_ERROR_UNKNOWN; \ + } \ + return reinterpret_cast(pfn)(args...); \ + } + +#else + +#define CUTLASS_CUDA_DRIVER_WRAPPER_DECL(func, ver) \ + template \ + CUresult call_##func(Args... 
args) { \ + cudaDriverEntryPointQueryResult cuda_status; \ + void* pfn = nullptr; \ + cudaError_t cuda_err = cudaGetDriverEntryPoint( \ + CUTLASS_CUDA_DRIVER_STRINGIFY(func), \ + &pfn, \ + cudaEnableDefault, \ + &cuda_status); \ + if (cuda_status != cudaDriverEntryPointSuccess || \ + cuda_err != cudaSuccess) { \ + return CUDA_ERROR_UNKNOWN; \ + } \ + return reinterpret_cast(pfn)(args...); \ + } + +#endif // (__CUDACC_VER_MAJOR__ >= 12 && __CUDACC_VER_MINOR__ >= 5) + +#endif // defined(CUTLASS_ENABLE_DIRECT_CUDA_DRIVER_CALL) + +#if (__CUDACC_VER_MAJOR__ >= 12) +CUTLASS_CUDA_DRIVER_WRAPPER_DECL(cuTensorMapEncodeTiled, 12000); +CUTLASS_CUDA_DRIVER_WRAPPER_DECL(cuTensorMapEncodeIm2col, 12000); +#endif + +#undef CUTLASS_CUDA_DRIVER_STRINGIFY + +#define CUTLASS_CUDA_DRIVER_WRAPPER_CALL(func) cutlass::call_##func + +#endif // !defined(__CUDACC_RTC__) + +///////////////////////////////////////////////////////////////////////////////////////////////// + /// This class manages runtime CUlaunchAttribute that can be supplied to CudaHostAdapter /// CudaHostLaunchAttributes will be an empty struct in earlier CTK where CUlaunchAttribute /// is not introduced. From f7b19de32c5d1f3cedfc735c2849f12b537522ee Mon Sep 17 00:00:00 2001 From: Shreya Gaur <48754356+Shreya-gaur@users.noreply.github.com> Date: Mon, 19 Aug 2024 22:21:42 -0400 Subject: [PATCH 11/53] minor fix for a double quote in CMakeLists.txt (#1727) --- CMakeLists.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index 4e1ffd7531..7419bdf5e5 100755 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -234,7 +234,7 @@ set(CUTLASS_NVCC_EMBED_CUBIN ON CACHE BOOL "Embed compiled CUDA kernel binaries set(CUTLASS_NVCC_EMBED_PTX ON CACHE BOOL "Embed compiled PTX into executables.") set(CUTLASS_NVCC_KEEP OFF CACHE BOOL "Keep intermediate files generated by NVCC.") set(CUTLASS_ENABLE_F16C OFF CACHE BOOL "Enable F16C x86 extensions in host code.") -set(CUTLASS_ENABLE_DIRECT_CUDA_DRIVER_CALL OFF CACHE BOOL "Enable CUTLASS to directly call driver API) +set(CUTLASS_ENABLE_DIRECT_CUDA_DRIVER_CALL OFF CACHE BOOL "Enable CUTLASS to directly call driver API.") ################################################################################ # From e1976daacc7b030ba672217eb5d96f5a663df4ab Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Aleksandar=20Samard=C5=BEi=C4=87?= <115986737+alexsamardzic@users.noreply.github.com> Date: Fri, 30 Aug 2024 05:11:06 +0200 Subject: [PATCH 12/53] Add support for mixed 4-bit/8-bit data types GEMM (#1413) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Add support for mixed 4-bit/8-bit data types GEMM * fix ( and ) --------- Co-authored-by: Aleksandar Samardžić Co-authored-by: Haicheng Wu --- .../gemm/device/default_gemm_configuration.h | 54 ++++++ .../gemm/warp/default_mma_tensor_op_sm80.h | 71 +++++++- .../gemm/warp/mma_mixed_input_tensor_op.h | 14 +- include/cutlass/numeric_conversion.h | 80 +++++++++ python/cutlass_library/generator.py | 163 ++++++++++++++++++ test/unit/core/fast_numeric_conversion.cu | 24 ++- test/unit/gemm/device/CMakeLists.txt | 6 + ...s8n_s32t_mixed_input_tensor_op_s32_sm80.cu | 95 ++++++++++ ..._s8n_s8t_mixed_input_tensor_op_s32_sm80.cu | 95 ++++++++++ ...s4n_s32t_mixed_input_tensor_op_s32_sm80.cu | 95 ++++++++++ ..._s4n_s8t_mixed_input_tensor_op_s32_sm80.cu | 95 ++++++++++ test/unit/gemm/warp/gemm_mixed_input_sm80.cu | 48 ++++++ tools/library/CMakeLists.txt | 1 + .../src/reference/gemm_int_mixed_input.cu | 130 ++++++++++++++ 
.../initialize_reference_operations.cu | 3 + 15 files changed, 960 insertions(+), 14 deletions(-) create mode 100644 test/unit/gemm/device/gemm_universal_s4t_s8n_s32t_mixed_input_tensor_op_s32_sm80.cu create mode 100644 test/unit/gemm/device/gemm_universal_s4t_s8n_s8t_mixed_input_tensor_op_s32_sm80.cu create mode 100644 test/unit/gemm/device/gemm_universal_s8t_s4n_s32t_mixed_input_tensor_op_s32_sm80.cu create mode 100644 test/unit/gemm/device/gemm_universal_s8t_s4n_s8t_mixed_input_tensor_op_s32_sm80.cu create mode 100644 tools/library/src/reference/gemm_int_mixed_input.cu diff --git a/include/cutlass/gemm/device/default_gemm_configuration.h b/include/cutlass/gemm/device/default_gemm_configuration.h index 4197a6b080..c9e7cc76d1 100644 --- a/include/cutlass/gemm/device/default_gemm_configuration.h +++ b/include/cutlass/gemm/device/default_gemm_configuration.h @@ -793,6 +793,60 @@ struct DefaultGemmConfigurationSm89F8 { using Operator = arch::OpMultiplyAdd; }; +//////////////////////////////////////////////////////////////////////////////// + +template < + typename ElementC> +struct DefaultGemmConfiguration< + arch::OpClassTensorOp, + arch::Sm80, + int4b_t, + int8_t, + ElementC, + int32_t> { + + static int const kAlignmentA = 128 / sizeof_bits::value; + static int const kAlignmentB = 128 / sizeof_bits::value; + + using ThreadblockShape = GemmShape<128, 256, 64>; + using WarpShape = GemmShape<64, 64, 64>; + using InstructionShape = GemmShape<16, 8, 32>; + static int const kStages = 3; + + using EpilogueOutputOp = epilogue::thread::LinearCombinationClamp< + ElementC, 128 / sizeof_bits::value, int32_t, float>; + + using Operator = arch::OpMultiplyAddSaturate; +}; + +//////////////////////////////////////////////////////////////////////////////// + +template < + typename ElementC> +struct DefaultGemmConfiguration< + arch::OpClassTensorOp, + arch::Sm80, + int8_t, + int4b_t, + ElementC, + int32_t> { + + static int const kAlignmentA = 128 / sizeof_bits::value; + static int const kAlignmentB = 128 / sizeof_bits::value; + + using ThreadblockShape = GemmShape<128, 256, 64>; + using WarpShape = GemmShape<64, 64, 64>; + using InstructionShape = GemmShape<16, 8, 32>; + static int const kStages = 3; + + using EpilogueOutputOp = epilogue::thread::LinearCombinationClamp< + ElementC, 128 / sizeof_bits::value, int32_t, float>; + + using Operator = arch::OpMultiplyAddSaturate; +}; + +//////////////////////////////////////////////////////////////////////////////// + /// Partial specialization for SM89 fe4m3 x fe4m3 template struct DefaultGemmConfiguration< diff --git a/include/cutlass/gemm/warp/default_mma_tensor_op_sm80.h b/include/cutlass/gemm/warp/default_mma_tensor_op_sm80.h index d7e3232c81..2c851f469a 100644 --- a/include/cutlass/gemm/warp/default_mma_tensor_op_sm80.h +++ b/include/cutlass/gemm/warp/default_mma_tensor_op_sm80.h @@ -268,7 +268,7 @@ struct DefaultMmaTensorOp< "DefaultMmaTensorOp with arch::OpMultiplyAddMixedInputUpcast ElementA and ElementB cannot be of the same data type"); // Data type used for internal computation - use the wider of the two data types for mma.sync operands - using ElementOperand = typename platform::conditional<(sizeof(ElementA) > sizeof(ElementB)), + using ElementOperand = typename platform::conditional<(sizeof_bits::value > sizeof_bits::value), ElementA, ElementB>::type; // Operand datatypes in the internal MMA instruction - use the wider of the two data types @@ -294,6 +294,75 @@ struct DefaultMmaTensorOp< Policy, PartitionsK, AccumulatorsInRowMajor>; }; + 
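The switch from sizeof to sizeof_bits in the earlier hunk of this file is what allows the wider-operand selection to keep working for the sub-byte case added below: int4b_t and int8_t both occupy one byte of storage, so sizeof cannot tell them apart, while sizeof_bits compares the logical widths. A compile-time illustration (a sketch, not part of the patch):

  static_assert(sizeof(cutlass::int4b_t) == sizeof(int8_t),
                "both types take one byte of storage, so sizeof ties");
  static_assert(cutlass::sizeof_bits<cutlass::int4b_t>::value == 4 &&
                cutlass::sizeof_bits<int8_t>::value == 8,
                "sizeof_bits selects int8_t as the wider mma.sync operand");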
+///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Partial Specialization - inputs are mixed types - uses wider datatype internally. +/// (e.g. S32 <= S4 x S8 + S32, S32 <= S8 x S4 + S32) +template < + /// Shape of one matrix production operation (concept: GemmShape) + typename WarpShape_, + /// Element type of A matrix + typename ElementA, + /// Layout of A matrix (concept: MatrixLayout) + typename LayoutA, + /// Element type of B matrix + typename ElementB, + /// Layout of B matrix (concept: MatrixLayout) + typename LayoutB, + /// Element type of C matrix + typename ElementC, + /// Layout of C matrix (concept: MatrixLayout) + typename LayoutC, + /// Number of partitions along K dimension + int PartitionsK, + /// Store the accumulators in row major or column major. Row major is used + /// when output layout is interleaved. + bool AccumulatorsInRowMajor> +struct DefaultMmaTensorOp< + WarpShape_, + GemmShape<16, 8, 32>, // InstructionShape + ElementA, // Element type of A matrix in Global Memory + LayoutA, // Layout of A matrix in Global Memory + ElementB, // Element type of B matrix in Global Memory + LayoutB, // Layout of B matrix in Global Memory + ElementC, // Element type of C matrix in Global Memory + LayoutC, // Layout of C matrix in Global Memory + arch::OpMultiplyAddMixedInputUpcast, // Tag to indicate mixed-input datatype, where narrower datatype is upcasted to wider datatype + PartitionsK, AccumulatorsInRowMajor> { + + + // Check if the ElementA and ElementB are of different data types + static_assert(!platform::is_same::value, + "DefaultMmaTensorOp with arch::OpMultiplyAddMixedInputUpcast ElementA and ElementB cannot be of the same data type"); + + // Data type used for internal computation - use the wider of the two data types for mma.sync operands + using ElementOperand = typename platform::conditional<(sizeof_bits::value > sizeof_bits::value), + ElementA, ElementB>::type; + + // Operand datatypes in the internal MMA instruction - use the wider of the two data types + using MmaElementA = ElementOperand; + using MmaElementB = ElementOperand; + using MmaElementC = ElementC; + + // Uses + using Policy = cutlass::gemm::warp::MmaTensorOpPolicy< + cutlass::arch::Mma< + GemmShape<16, 8, 32>, + 32, + MmaElementA, cutlass::layout::RowMajor, + MmaElementB, cutlass::layout::ColumnMajor, + MmaElementC, cutlass::layout::RowMajor, + arch::OpMultiplyAddSaturate + >, + cutlass::MatrixShape<1, 1> >; + + // Define the warp-level tensor op + using Type = cutlass::gemm::warp::MmaMixedInputTensorOp< + WarpShape_, ElementA, LayoutA, ElementB, LayoutB, ElementC, LayoutC, + Policy, PartitionsK, AccumulatorsInRowMajor>; +}; + ///////////////////////////////////////////////////////////////////////////////////////////////// } // namespace warp diff --git a/include/cutlass/gemm/warp/mma_mixed_input_tensor_op.h b/include/cutlass/gemm/warp/mma_mixed_input_tensor_op.h index 14d8f33455..0b37ad24c6 100644 --- a/include/cutlass/gemm/warp/mma_mixed_input_tensor_op.h +++ b/include/cutlass/gemm/warp/mma_mixed_input_tensor_op.h @@ -104,6 +104,7 @@ struct FragmentShuffler { //////////////////////////////////////////////////////////////////////////////// /// Partial specialization for `mma.sync` on 16b (F16/BF16) and `ldmatrix` on 8b (S8/U8) +/// or for `mma.sync` on 8b (S8/U8) and `ldmatrix` on 4b (S4/U4) /// for operand A multiplicand going through upcasting. 
template < /// Element type for the operand in registers for the mma.sync @@ -122,8 +123,10 @@ struct FragmentShuffler ::value == 16) && - (sizeof_bits::value == 8)>::type> { + typename platform::enable_if<((sizeof_bits::value == 16) && + (sizeof_bits::value == 8)) || + ((sizeof_bits::value == 8) && + (sizeof_bits::value == 4))>::type> { public: using ElementMma = ElementMma_; using ElementLoad = ElementLoad_; @@ -187,6 +190,7 @@ struct FragmentShuffler ::value == 16) && - (sizeof_bits::value == 8)>::type> { + typename platform::enable_if<((sizeof_bits::value == 16) && + (sizeof_bits::value == 8)) || + ((sizeof_bits::value == 8) && + (sizeof_bits::value == 4))>::type> { public: using ElementMma = ElementMma_; using ElementLoad = ElementLoad_; diff --git a/include/cutlass/numeric_conversion.h b/include/cutlass/numeric_conversion.h index 2e74afa8e4..1701b4ac8d 100644 --- a/include/cutlass/numeric_conversion.h +++ b/include/cutlass/numeric_conversion.h @@ -2771,6 +2771,86 @@ struct NumericArrayConverter { } }; +/// Partial specialization for Array <= Array +template < + FloatRoundStyle Round +> +struct NumericArrayConverter { + + using result_type = Array; + using source_type = Array; + static FloatRoundStyle const round_style = Round; + + CUTLASS_HOST_DEVICE + static result_type convert(source_type const & source) { + + unsigned const& storage = reinterpret_cast(source); + unsigned out[2]; + + asm volatile( + "{ .reg .u32 tmp0, tmp1, tmp2;" + "shl.b32 tmp0, %2, 4;" + "and.b32 tmp0, tmp0, 0xf0f0f0f0;" + "prmt.b32 tmp1, tmp0, tmp0, 0xba98;" + "and.b32 tmp1, tmp1, 0xf0f0f0f0;" + "shr.u32 tmp0, tmp0, 4;" + "or.b32 tmp2, tmp0, tmp1;" + "and.b32 tmp0, %2, 0xf0f0f0f0;" + "prmt.b32 tmp1, tmp0, tmp0, 0xba98;" + "and.b32 tmp1, tmp1, 0xf0f0f0f0;" + "shr.u32 tmp0, tmp0, 4;" + "or.b32 tmp0, tmp0, tmp1;" + "prmt.b32 %0, tmp2, tmp0, 0x5140;" + "prmt.b32 %1, tmp2, tmp0, 0x7362;" + "}" + : "=r"(out[0]), "=r"(out[1]) + : "r"(storage)); + + return reinterpret_cast(out); + } + + CUTLASS_HOST_DEVICE + result_type operator()(source_type const &s) const { + return convert(s); + } +}; + +/// Partial specialization for Array <= Array +template < + int N, + FloatRoundStyle Round +> +struct NumericArrayConverter { + static_assert(!(N % 8), "N must be multiple of 8."); + + using result_type = Array; + using source_type = Array; + static FloatRoundStyle const round_style = Round; + + CUTLASS_HOST_DEVICE + static result_type convert(source_type const & source) { + + NumericArrayConverter convert_vector_; + + result_type result; + + Array *result_ptr = reinterpret_cast *>(&result); + Array const *source_ptr = reinterpret_cast const *>(&source); + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N / 8; ++i) { + result_ptr[i] = convert_vector_(source_ptr[i]); + } + + return result; + } + + CUTLASS_HOST_DEVICE + result_type operator()(source_type const &s) const { + return convert(s); + } +}; + #endif // Conditional guards to enable partial specialization for packed integers namespace detail { diff --git a/python/cutlass_library/generator.py b/python/cutlass_library/generator.py index 7a5a47b196..c736551432 100644 --- a/python/cutlass_library/generator.py +++ b/python/cutlass_library/generator.py @@ -2855,6 +2855,167 @@ def GenerateSM80_TensorOp_16832_TN(manifest, cuda_version): op.C.alignment = 8 # +def GenerateSM80_TensorOp_16832_TN_mixed_input_upcast_a(manifest, cuda_version): + + if not CudaToolkitVersionSatisfies(cuda_version, 11, 0): + return + + layouts = [ + (LayoutType.RowMajor, LayoutType.ColumnMajor, 
LayoutType.ColumnMajor), + ] + + # Upcast on Operand A + math_instructions = [ + MathInstruction( \ + [16, 8, 32], \ + DataType.s4, DataType.s8, DataType.s32, \ + OpcodeClass.TensorOp, \ + MathOperation.multiply_add_mixed_input_upcast), + ] + + min_cc = 80 + max_cc = 1024 + + # For mixed-input alignment constraints are a list of lists, where the + # inner list contains the alignment constraints for operands/matrices + # [[alignA, alignB, alignC],..] + alignment_constraints = [[32, 16, 4],] + + for math_inst in math_instructions: + tile_descriptions = [ + TileDescription([256, 128, 64], 3, [4, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 256, 64], 3, [2, 4, 1], math_inst, min_cc, max_cc), + TileDescription([256, 64, 64], 4, [4, 1, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 256, 64], 4, [1, 4, 1], math_inst, min_cc, max_cc), + TileDescription([ 32, 256, 64], 4, [1, 4, 1], math_inst, min_cc, max_cc), + TileDescription([128, 128, 64], 5, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 128, 64], 6, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([256, 128, 128], 3, [4, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 256, 128], 3, [2, 4, 1], math_inst, min_cc, max_cc), + TileDescription([256, 64, 128], 4, [4, 1, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 256, 128], 4, [1, 4, 1], math_inst, min_cc, max_cc), + TileDescription([256, 32, 128], 4, [4, 1, 1], math_inst, min_cc, max_cc), + TileDescription([ 32, 256, 128], 4, [1, 4, 1], math_inst, min_cc, max_cc), + TileDescription([128, 128, 128], 4, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 128, 128], 3, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 32, 128], 4, [4, 1, 1], math_inst, min_cc, max_cc), + ] + + data_type = [ + math_inst.element_a, + math_inst.element_b, + math_inst.element_accumulator, + math_inst.element_accumulator, + ] + + # streamk uses more regs which can cause spill for the biggest warp tile size when the accumulators are 32bit. + operations = CreateGemmOperator(manifest, layouts, tile_descriptions, \ + data_type, alignment_constraints, None, EpilogueFunctor.LinearCombination, SwizzlingFunctor.Identity8) + + # Avoid emitting two kernels if the accumulator type does not differ from the input type (e.g. 
S8 accumulation) + if math_inst.element_a != math_inst.element_accumulator: + alignment_constraints = [[32, 16, 16],] + + data_type_mixed = [ + math_inst.element_a, + math_inst.element_b, + math_inst.element_b, + DataType.f32 + ] + + operations += CreateGemmOperator(manifest, layouts, tile_descriptions, \ + data_type_mixed, alignment_constraints, None, EpilogueFunctor.LinearCombinationClamp, SwizzlingFunctor.Identity8) + + for op in operations: + if op.tile_description.threadblock_shape[1] >= 128: + if op.tile_description.threadblock_shape[0] == 32: + op.C.alignment = 8 + else: + op.C.alignment = 16 + else: + op.C.alignment = 8 + +# +def GenerateSM80_TensorOp_16832_TN_mixed_input_upcast_b(manifest, cuda_version): + + if not CudaToolkitVersionSatisfies(cuda_version, 11, 0): + return + + layouts = [ + (LayoutType.RowMajor, LayoutType.ColumnMajor, LayoutType.ColumnMajor), + ] + + # Upcast on Operand B + math_instructions = [ + MathInstruction( \ + [16, 8, 32], \ + DataType.s8, DataType.s4, DataType.s32, \ + OpcodeClass.TensorOp, \ + MathOperation.multiply_add_mixed_input_upcast), + ] + + min_cc = 80 + max_cc = 1024 + + # For mixed-input alignment constraints are a list of lists, where the + # inner list contains the alignment constraints for operands/matrices + # [[alignA, alignB, alignC],..] + alignment_constraints = [[16, 32, 4],] + + for math_inst in math_instructions: + tile_descriptions = [ + TileDescription([256, 128, 64], 3, [4, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 256, 64], 3, [2, 4, 1], math_inst, min_cc, max_cc), + TileDescription([256, 64, 64], 4, [4, 1, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 256, 64], 4, [1, 4, 1], math_inst, min_cc, max_cc), + TileDescription([256, 32, 64], 4, [4, 1, 1], math_inst, min_cc, max_cc), + TileDescription([128, 128, 64], 5, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 128, 64], 6, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 32, 64], 6, [4, 1, 1], math_inst, min_cc, max_cc), + TileDescription([256, 128, 128], 3, [4, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 256, 128], 3, [2, 4, 1], math_inst, min_cc, max_cc), + TileDescription([256, 64, 128], 4, [4, 1, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 256, 128], 4, [1, 4, 1], math_inst, min_cc, max_cc), + TileDescription([256, 32, 128], 4, [4, 1, 1], math_inst, min_cc, max_cc), + TileDescription([ 32, 256, 128], 4, [1, 4, 1], math_inst, min_cc, max_cc), + TileDescription([128, 128, 128], 4, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 128, 128], 3, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 32, 128], 4, [4, 1, 1], math_inst, min_cc, max_cc), + ] + + data_type = [ + math_inst.element_a, + math_inst.element_b, + math_inst.element_accumulator, + math_inst.element_accumulator, + ] + + # streamk uses more regs which can cause spill for the biggest warp tile size when the accumulators are 32bit. + operations = CreateGemmOperator(manifest, layouts, tile_descriptions, \ + data_type, alignment_constraints, None, EpilogueFunctor.LinearCombination, SwizzlingFunctor.Identity8) + + # Avoid emitting two kernels if the accumulator type does not differ from the input type (e.g. 
S8 accumulation) + if math_inst.element_a != math_inst.element_accumulator: + alignment_constraints = [[16, 32, 16],] + + data_type_mixed = [ + math_inst.element_a, + math_inst.element_b, + math_inst.element_a, + DataType.f32, + ] + + operations += CreateGemmOperator(manifest, layouts, tile_descriptions, \ + data_type_mixed, alignment_constraints, None, EpilogueFunctor.LinearCombinationClamp, SwizzlingFunctor.Identity8) + + for op in operations: + if op.tile_description.threadblock_shape[1] >= 128: + if op.tile_description.threadblock_shape[0] == 32: + op.C.alignment = 8 + else: + op.C.alignment = 16 + else: + op.C.alignment = 8 # def GenerateSM80_SparseTensorOp_16864_TN(manifest, cuda_version): @@ -4699,6 +4860,8 @@ def GenerateSM80(manifest, cuda_version): GenerateSM80_TensorOp_16816_mixed_input_upcast_a(manifest, cuda_version) GenerateSM80_TensorOp_16816_mixed_input_upcast_b(manifest, cuda_version) GenerateSM80_TensorOp_16832_TN(manifest, cuda_version) + GenerateSM80_TensorOp_16832_TN_mixed_input_upcast_a(manifest, cuda_version) + GenerateSM80_TensorOp_16832_TN_mixed_input_upcast_b(manifest, cuda_version) GenerateSM80_SparseTensorOp_16864_TN(manifest, cuda_version) GenerateSM80_TensorOp_16832_Interleaved(manifest, cuda_version) GenerateSM80_TensorOp_16864_TN(manifest, cuda_version) diff --git a/test/unit/core/fast_numeric_conversion.cu b/test/unit/core/fast_numeric_conversion.cu index 0d6e2401a9..99aab24581 100644 --- a/test/unit/core/fast_numeric_conversion.cu +++ b/test/unit/core/fast_numeric_conversion.cu @@ -69,7 +69,7 @@ void run_test_integer_range_limited() { cutlass::HostTensor source({1, kN}); for (int i = 0; i < kN; ++i) { - source.host_data()[i] = Source(i % 4); + source.host_view().at({0, i}) = Source(i % 4); } source.sync_device(); @@ -82,7 +82,7 @@ void run_test_integer_range_limited() { destination.sync_host(); for (int i = 0; i < kN; ++i) { - EXPECT_TRUE(float(destination.host_data()[i]) == float(source.host_data()[i])); + EXPECT_TRUE(float(destination.host_view().at({0, i})) == float(source.host_view().at({0, i}))); } } @@ -97,13 +97,12 @@ void run_test_integer_range_all() { cutlass::HostTensor destination({1, kN}); cutlass::HostTensor source({1, kN}); - int const kIntSourceMin = std::numeric_limits::min(); - int const kIntSourceMax = std::numeric_limits::max(); + int const kIntSourceMin = cutlass::platform::numeric_limits::lowest(); + int const kIntSourceMax = cutlass::platform::numeric_limits::max(); int const kIntRange = kIntSourceMax - kIntSourceMin + 1; for (int i = 0; i < kN; ++i) { - source.host_data()[i] = Source(kIntSourceMin + (i % kIntRange)); - + source.host_view().at({0, i}) = Source(kIntSourceMin + (i % kIntRange)); } source.sync_device(); @@ -118,7 +117,7 @@ void run_test_integer_range_all() { // Verify conversion bool passed = true; for (int i = 0; i < kN; ++i) { - if(!(float(destination.host_data()[i]) == float(source.host_data()[i]))) { + if(!(float(destination.host_view().at({0, i})) == float(source.host_view().at({0, i})))) { passed = false; break; } @@ -128,8 +127,8 @@ void run_test_integer_range_all() { // Print out results for the failed conversion. 
if (!passed) { for (int i = 0; i < kN; ++i) { - std::cout << "source(" << float(source.host_data()[i]) << ") -> " - << "destination ("<< float(destination.host_data()[i]) << ")" << std::endl; + std::cout << "source(" << float(source.host_view().at({0, i})) << ") -> " + << "destination ("<< float(destination.host_view().at({0, i})) << ")" << std::endl; } } std::flush(std::cout); @@ -188,3 +187,10 @@ TEST(FastNumericConversion, s8_to_bf16_array) { using Destination = cutlass::bfloat16_t; test::core::kernel::run_test_integer_range_all(); } + +TEST(FastNumericConversion, s4_to_s8_array) { + int const kN = 16; + using Source = cutlass::int4b_t; + using Destination = int8_t; + test::core::kernel::run_test_integer_range_all(); +} diff --git a/test/unit/gemm/device/CMakeLists.txt b/test/unit/gemm/device/CMakeLists.txt index b5afa433e9..a70ce542d0 100644 --- a/test/unit/gemm/device/CMakeLists.txt +++ b/test/unit/gemm/device/CMakeLists.txt @@ -264,6 +264,9 @@ cutlass_test_unit_add_executable( gemm_universal_s8t_f16n_f16t_mixed_input_tensor_op_f16_sm80.cu gemm_universal_u8t_f16n_f16t_mixed_input_tensor_op_f16_sm80.cu + gemm_universal_s4t_s8n_s32t_mixed_input_tensor_op_s32_sm80.cu + gemm_universal_s4t_s8n_s8t_mixed_input_tensor_op_s32_sm80.cu + # Upcast on Operand B gemm_universal_f16t_s8n_f32t_mixed_input_tensor_op_f32_sm80.cu gemm_universal_f16t_u8n_f32t_mixed_input_tensor_op_f32_sm80.cu @@ -277,6 +280,9 @@ cutlass_test_unit_add_executable( gemm_universal_f16t_s8n_f16t_mixed_input_tensor_op_f16_sm80.cu gemm_universal_f16t_u8n_f16t_mixed_input_tensor_op_f16_sm80.cu + + gemm_universal_s8t_s4n_s32t_mixed_input_tensor_op_s32_sm80.cu + gemm_universal_s8t_s4n_s8t_mixed_input_tensor_op_s32_sm80.cu ) cutlass_test_unit_add_executable( diff --git a/test/unit/gemm/device/gemm_universal_s4t_s8n_s32t_mixed_input_tensor_op_s32_sm80.cu b/test/unit/gemm/device/gemm_universal_s4t_s8n_s32t_mixed_input_tensor_op_s32_sm80.cu new file mode 100644 index 0000000000..421ea0c0b2 --- /dev/null +++ b/test/unit/gemm/device/gemm_universal_s4t_s8n_s32t_mixed_input_tensor_op_s32_sm80.cu @@ -0,0 +1,95 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Tests for device-wide GEMM interface + +*/ + +#include "../../common/cutlass_unit_test.h" +#include "cutlass/cutlass.h" + +#include "cutlass/gemm/device/gemm_universal.h" + +#include "cutlass/util/host_tensor.h" +#include "cutlass/util/reference/host/gemm.h" +#include "cutlass/util/reference/host/tensor_compare.h" +#include "cutlass/util/reference/host/tensor_copy.h" +#include "cutlass/util/reference/host/tensor_fill.h" +#include "cutlass/util/tensor_view_io.h" + +#include "testbed_universal.h" + +//////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) + +//////////////////////////////////////////////////////////////////////////////// + + +TEST(SM80_Device_GemmUniversal_s4t_s8n_s32t_mixed_input_tensor_op_s32, 128x128x64_64x64x64) { + + using ElementA = cutlass::int4b_t; + using ElementB = int8_t; + using ElementOutput = int32_t; + using ElementAccumulator = int32_t; + + using Gemm = cutlass::gemm::device::GemmUniversal< + ElementA, + cutlass::layout::RowMajor, + ElementB, + cutlass::layout::ColumnMajor, + ElementOutput, + cutlass::layout::RowMajor, + ElementAccumulator, + cutlass::arch::OpClassTensorOp, + cutlass::arch::Sm80, + cutlass::gemm::GemmShape<128, 128, 64>, + cutlass::gemm::GemmShape<64, 64, 64>, + cutlass::gemm::GemmShape<16, 8, 32>, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, 128 / cutlass::sizeof_bits::value, + ElementAccumulator, ElementAccumulator>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 4, // Stages + 32, // AlignmentA + 16, // AlignmentB + cutlass::arch::OpMultiplyAddMixedInputUpcast, + cutlass::ComplexTransform::kNone, + cutlass::ComplexTransform::kNone + >; + + EXPECT_TRUE(test::gemm::device::TestAllGemmUniversal()); +} +//////////////////////////////////////////////////////////////////////////////// + +#endif // #if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) + +//////////////////////////////////////////////////////////////////////////////// diff --git a/test/unit/gemm/device/gemm_universal_s4t_s8n_s8t_mixed_input_tensor_op_s32_sm80.cu b/test/unit/gemm/device/gemm_universal_s4t_s8n_s8t_mixed_input_tensor_op_s32_sm80.cu new file mode 100644 index 0000000000..685092fb84 --- /dev/null +++ b/test/unit/gemm/device/gemm_universal_s4t_s8n_s8t_mixed_input_tensor_op_s32_sm80.cu @@ -0,0 +1,95 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. 
+ * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Tests for device-wide GEMM interface + +*/ + +#include "../../common/cutlass_unit_test.h" +#include "cutlass/cutlass.h" + +#include "cutlass/gemm/device/gemm_universal.h" + +#include "cutlass/util/host_tensor.h" +#include "cutlass/util/reference/host/gemm.h" +#include "cutlass/util/reference/host/tensor_compare.h" +#include "cutlass/util/reference/host/tensor_copy.h" +#include "cutlass/util/reference/host/tensor_fill.h" +#include "cutlass/util/tensor_view_io.h" + +#include "testbed_universal.h" + +//////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) + +//////////////////////////////////////////////////////////////////////////////// + + +TEST(SM80_Device_GemmUniversal_s4t_s8n_s8t_mixed_input_tensor_op_s32, 128x128x64_64x64x64) { + + using ElementA = cutlass::int4b_t; + using ElementB = int8_t; + using ElementOutput = int8_t; + using ElementAccumulator = int32_t; + + using Gemm = cutlass::gemm::device::GemmUniversal< + ElementA, + cutlass::layout::RowMajor, + ElementB, + cutlass::layout::ColumnMajor, + ElementOutput, + cutlass::layout::RowMajor, + ElementAccumulator, + cutlass::arch::OpClassTensorOp, + cutlass::arch::Sm80, + cutlass::gemm::GemmShape<128, 128, 64>, + cutlass::gemm::GemmShape<64, 64, 64>, + cutlass::gemm::GemmShape<16, 8, 32>, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, 128 / cutlass::sizeof_bits::value, + ElementAccumulator, ElementAccumulator>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 4, // Stages + 32, // AlignmentA + 16, // AlignmentB + cutlass::arch::OpMultiplyAddMixedInputUpcast, + cutlass::ComplexTransform::kNone, + cutlass::ComplexTransform::kNone + >; + + EXPECT_TRUE(test::gemm::device::TestAllGemmUniversal()); +} +//////////////////////////////////////////////////////////////////////////////// + +#endif // #if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) + +//////////////////////////////////////////////////////////////////////////////// diff --git a/test/unit/gemm/device/gemm_universal_s8t_s4n_s32t_mixed_input_tensor_op_s32_sm80.cu b/test/unit/gemm/device/gemm_universal_s8t_s4n_s32t_mixed_input_tensor_op_s32_sm80.cu new 
file mode 100644 index 0000000000..b28cee62c0 --- /dev/null +++ b/test/unit/gemm/device/gemm_universal_s8t_s4n_s32t_mixed_input_tensor_op_s32_sm80.cu @@ -0,0 +1,95 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! 
\file + \brief Tests for device-wide GEMM interface + +*/ + +#include "../../common/cutlass_unit_test.h" +#include "cutlass/cutlass.h" + +#include "cutlass/gemm/device/gemm_universal.h" + +#include "cutlass/util/host_tensor.h" +#include "cutlass/util/reference/host/gemm.h" +#include "cutlass/util/reference/host/tensor_compare.h" +#include "cutlass/util/reference/host/tensor_copy.h" +#include "cutlass/util/reference/host/tensor_fill.h" +#include "cutlass/util/tensor_view_io.h" + +#include "testbed_universal.h" + +//////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) + +//////////////////////////////////////////////////////////////////////////////// + + +TEST(SM80_Device_GemmUniversal_s8t_s4n_s32t_mixed_input_tensor_op_s32, 128x128x64_64x64x64) { + + using ElementA = int8_t; + using ElementB = cutlass::int4b_t; + using ElementOutput = int32_t; + using ElementAccumulator = int32_t; + + using Gemm = cutlass::gemm::device::GemmUniversal< + ElementA, + cutlass::layout::RowMajor, + ElementB, + cutlass::layout::ColumnMajor, + ElementOutput, + cutlass::layout::RowMajor, + ElementAccumulator, + cutlass::arch::OpClassTensorOp, + cutlass::arch::Sm80, + cutlass::gemm::GemmShape<128, 128, 64>, + cutlass::gemm::GemmShape<64, 64, 64>, + cutlass::gemm::GemmShape<16, 8, 32>, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, 128 / cutlass::sizeof_bits::value, + ElementAccumulator, ElementAccumulator>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 4, // Stages + 16, // AlignmentA + 32, // AlignmentB + cutlass::arch::OpMultiplyAddMixedInputUpcast, + cutlass::ComplexTransform::kNone, + cutlass::ComplexTransform::kNone + >; + + EXPECT_TRUE(test::gemm::device::TestAllGemmUniversal()); +} +//////////////////////////////////////////////////////////////////////////////// + +#endif // #if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) + +//////////////////////////////////////////////////////////////////////////////// diff --git a/test/unit/gemm/device/gemm_universal_s8t_s4n_s8t_mixed_input_tensor_op_s32_sm80.cu b/test/unit/gemm/device/gemm_universal_s8t_s4n_s8t_mixed_input_tensor_op_s32_sm80.cu new file mode 100644 index 0000000000..89a52b3e80 --- /dev/null +++ b/test/unit/gemm/device/gemm_universal_s8t_s4n_s8t_mixed_input_tensor_op_s32_sm80.cu @@ -0,0 +1,95 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. 
+ * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Tests for device-wide GEMM interface + +*/ + +#include "../../common/cutlass_unit_test.h" +#include "cutlass/cutlass.h" + +#include "cutlass/gemm/device/gemm_universal.h" + +#include "cutlass/util/host_tensor.h" +#include "cutlass/util/reference/host/gemm.h" +#include "cutlass/util/reference/host/tensor_compare.h" +#include "cutlass/util/reference/host/tensor_copy.h" +#include "cutlass/util/reference/host/tensor_fill.h" +#include "cutlass/util/tensor_view_io.h" + +#include "testbed_universal.h" + +//////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) + +//////////////////////////////////////////////////////////////////////////////// + + +TEST(SM80_Device_GemmUniversal_s8t_s4n_s8t_mixed_input_tensor_op_s32, 128x128x64_64x64x64) { + + using ElementA = int8_t; + using ElementB = cutlass::int4b_t; + using ElementOutput = int8_t; + using ElementAccumulator = int32_t; + + using Gemm = cutlass::gemm::device::GemmUniversal< + ElementA, + cutlass::layout::RowMajor, + ElementB, + cutlass::layout::ColumnMajor, + ElementOutput, + cutlass::layout::RowMajor, + ElementAccumulator, + cutlass::arch::OpClassTensorOp, + cutlass::arch::Sm80, + cutlass::gemm::GemmShape<128, 128, 64>, + cutlass::gemm::GemmShape<64, 64, 64>, + cutlass::gemm::GemmShape<16, 8, 32>, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, 128 / cutlass::sizeof_bits::value, + ElementAccumulator, ElementAccumulator>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 4, // Stages + 16, // AlignmentA + 32, // AlignmentB + cutlass::arch::OpMultiplyAddMixedInputUpcast, + cutlass::ComplexTransform::kNone, + cutlass::ComplexTransform::kNone + >; + + EXPECT_TRUE(test::gemm::device::TestAllGemmUniversal()); +} +//////////////////////////////////////////////////////////////////////////////// + +#endif // #if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) + +//////////////////////////////////////////////////////////////////////////////// diff --git a/test/unit/gemm/warp/gemm_mixed_input_sm80.cu b/test/unit/gemm/warp/gemm_mixed_input_sm80.cu index eb7d8023d0..db5b178f38 100644 --- a/test/unit/gemm/warp/gemm_mixed_input_sm80.cu +++ b/test/unit/gemm/warp/gemm_mixed_input_sm80.cu @@ -324,4 +324,52 @@ TEST(SM80_warp_gemm_mixed_input_tensor_op_crosswise_i8_bf16, 64x64x64_64x64x64_1 .run(); } +//////////////////////////////////////////////////////////////////////////////// +/// S32 <= I4 * I8 + S32 (Upcast on Operand A) +//////////////////////////////////////////////////////////////////////////////// + 
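Before the tests themselves, a brief aside: OpMultiplyAddMixedInputUpcast widens the narrower operand to the wider operand's type in registers and then issues the ordinary s8 * s8 Tensor Core MMA. A self-contained sketch of that widening step using cutlass::NumericArrayConverter, with an arbitrary fragment size of 8 (the helper name is illustrative, not part of this patch):

#include "cutlass/array.h"
#include "cutlass/numeric_conversion.h"
#include "cutlass/numeric_types.h"

// Widen a register fragment of 4-bit integers to 8-bit integers element-wise,
// as the mixed-input MMA does for operand A before the s8 * s8 instruction.
cutlass::Array<int8_t, 8> upcast_operand_a(cutlass::Array<cutlass::int4b_t, 8> const &frag_a) {
  cutlass::NumericArrayConverter<int8_t, cutlass::int4b_t, 8> convert_a;
  return convert_a(frag_a);
}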
+TEST(SM80_warp_gemm_mixed_input_tensor_op_crosswise_i4_i8, 64x64x64_64x64x64_16x8x32) {
+  using Shape = cutlass::gemm::GemmShape<64, 64, 64>;
+  using InstructionShape = cutlass::gemm::GemmShape<16, 8, 32>;
+  using ElementA = cutlass::int4b_t;
+  using ElementB = int8_t;
+  using ElementC = int32_t;
+  using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise<
+      cutlass::sizeof_bits<ElementA>::value, 64>;
+  using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise<
+      cutlass::sizeof_bits<ElementB>::value, 64>;
+
+  using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp<
+      Shape, InstructionShape, ElementA, LayoutA, ElementB, LayoutB, ElementC,
+      cutlass::layout::RowMajor, cutlass::arch::OpMultiplyAddMixedInputUpcast>::Type;
+
+  test::gemm::warp::TransformTestbed<MmaTensorOp, cutlass::gemm::GemmShape<64, 64, 64> >()
+      .run();
+}
+
+////////////////////////////////////////////////////////////////////////////////
+/// S32 <= I8 * I4 + S32 (Upcast on Operand B)
+////////////////////////////////////////////////////////////////////////////////
+
+TEST(SM80_warp_gemm_mixed_input_tensor_op_crosswise_i8_i4, 64x64x64_64x64x64_16x8x32) {
+  using Shape = cutlass::gemm::GemmShape<64, 64, 64>;
+  using InstructionShape = cutlass::gemm::GemmShape<16, 8, 32>;
+  using ElementA = int8_t;
+  using ElementB = cutlass::int4b_t;
+  using ElementC = int32_t;
+  using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise<
+      cutlass::sizeof_bits<ElementA>::value, 64>;
+  using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise<
+      cutlass::sizeof_bits<ElementB>::value, 64>;
+
+  using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp<
+      Shape, InstructionShape, ElementA, LayoutA, ElementB, LayoutB, ElementC,
+      cutlass::layout::RowMajor, cutlass::arch::OpMultiplyAddMixedInputUpcast>::Type;
+
+  test::gemm::warp::TransformTestbed<MmaTensorOp, cutlass::gemm::GemmShape<64, 64, 64> >()
+      .run();
+}
+
 #endif // if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED)
diff --git a/tools/library/CMakeLists.txt b/tools/library/CMakeLists.txt
index f8a28fe6b9..9b54f50817 100644
--- a/tools/library/CMakeLists.txt
+++ b/tools/library/CMakeLists.txt
@@ -234,6 +234,7 @@ cutlass_add_cutlass_library(
   src/reference/gemm_fp32out.cu
   src/reference/gemm_fp_other.cu
   src/reference/gemm_fp_mixed_input.cu
+  src/reference/gemm_int_mixed_input.cu
   src/reference/initialize_reference_operations.cu
 
   # cutlass reduction instances in cutlass library
diff --git a/tools/library/src/reference/gemm_int_mixed_input.cu b/tools/library/src/reference/gemm_int_mixed_input.cu
new file mode 100644
index 0000000000..8d6072e3ef
--- /dev/null
+++ b/tools/library/src/reference/gemm_int_mixed_input.cu
@@ -0,0 +1,130 @@
+/***************************************************************************************************
+ * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ * SPDX-License-Identifier: BSD-3-Clause
+ *
+ * Redistribution and use in source and binary forms, with or without
+ * modification, are permitted provided that the following conditions are met:
+ *
+ * 1. Redistributions of source code must retain the above copyright notice, this
+ * list of conditions and the following disclaimer.
+ *
+ * 2. Redistributions in binary form must reproduce the above copyright notice,
+ * this list of conditions and the following disclaimer in the documentation
+ * and/or other materials provided with the distribution.
+ *
+ * 3.
Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/* \file + \brief Instantiates GEMM reference implementations. +*/ + +#include "cutlass/cutlass.h" +#include "cutlass/library/library.h" +#include "cutlass/library/manifest.h" + +#include "gemm_reference_operation.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace library { + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +void initialize_gemm_reference_operations_int_mixed_input(Manifest &manifest) { + // 4-bit integer mixed with 8-bit integer input + make_gemm_real_canonical_layouts< + int4b_t, + int8_t, + int32_t, + int32_t + >(manifest); + + make_gemm_real_canonical_layouts< + int4b_t, + int8_t, + int8_t, + int32_t, + int32_t, + int8_t, + NumericConverterClamp + >(manifest); + + make_gemm_real_canonical_layouts< + int4b_t, + int8_t, + int32_t, + float, + int32_t, + int32_t, + NumericConverterClamp + >(manifest); + + make_gemm_real_canonical_layouts< + int4b_t, + int8_t, + int8_t, + float, + int32_t, + int8_t, + NumericConverterClamp + >(manifest); + + make_gemm_real_canonical_layouts< + int8_t, + int4b_t, + int32_t, + int32_t + >(manifest); + + make_gemm_real_canonical_layouts< + int8_t, + int4b_t, + int8_t, + int32_t, + int32_t, + int8_t, + NumericConverterClamp + >(manifest); + + make_gemm_real_canonical_layouts< + int8_t, + int4b_t, + int32_t, + float, + int32_t, + int32_t, + NumericConverterClamp + >(manifest); + + make_gemm_real_canonical_layouts< + int8_t, + int4b_t, + int8_t, + float, + int32_t, + int8_t, + NumericConverterClamp + >(manifest); +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace library +} // namespace cutlass + +/////////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/tools/library/src/reference/initialize_reference_operations.cu b/tools/library/src/reference/initialize_reference_operations.cu index 16679a27d8..59872b9742 100644 --- a/tools/library/src/reference/initialize_reference_operations.cu +++ b/tools/library/src/reference/initialize_reference_operations.cu @@ -57,6 +57,7 @@ void initialize_gemm_reference_operations_fp8in_fp32out(Manifest &manifest); void initialize_gemm_reference_operations_fp32out(Manifest &manifest); void initialize_gemm_reference_operations_fp_other(Manifest &manifest); void 
initialize_gemm_reference_operations_fp_mixed_input(Manifest &manifest); +void initialize_gemm_reference_operations_int_mixed_input(Manifest &manifest); void initialize_conv2d_reference_operations(Manifest &manifest); void initialize_conv3d_reference_operations(Manifest &manifest); @@ -85,6 +86,8 @@ void initialize_reference_operations(Manifest &manifest) { initialize_gemm_reference_operations_fp_other(manifest); initialize_gemm_reference_operations_fp_mixed_input(manifest); + initialize_gemm_reference_operations_int_mixed_input(manifest); + } /////////////////////////////////////////////////////////////////////////////////////////////////// From 6c3044136b6462d0ff028ece1c1a83bb90a5b3aa Mon Sep 17 00:00:00 2001 From: Alchan Kim Date: Thu, 5 Sep 2024 03:52:11 +0900 Subject: [PATCH 13/53] Update barrier.h (#1782) --- include/cutlass/barrier.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/include/cutlass/barrier.h b/include/cutlass/barrier.h index 94f300add9..6f2373b6df 100644 --- a/include/cutlass/barrier.h +++ b/include/cutlass/barrier.h @@ -277,7 +277,7 @@ struct NamedBarrierManager { CUTLASS_DEVICE static void check_barrier_in_range([[maybe_unused]] uint32_t idx) { - assert((idx >= MaxNumNamedBarriers) && "Index exceeds barrier count"); + assert((idx < MaxNumNamedBarriers) && "Index exceeds barrier count"); } template From 7369adcaca5b9db84ec04b6f52a8d1f8ef968e8d Mon Sep 17 00:00:00 2001 From: JiayuSun <470323747@qq.com> Date: Thu, 5 Sep 2024 03:11:24 +0800 Subject: [PATCH 14/53] Add Sm90LinCombPerColBias (#1774) Co-authored-by: Jiayu Sun --- .../cutlass/epilogue/fusion/operations.hpp | 17 ++++ .../sm90_callbacks_tma_warpspecialized.hpp | 84 +++++++++++++++++++ 2 files changed, 101 insertions(+) diff --git a/include/cutlass/epilogue/fusion/operations.hpp b/include/cutlass/epilogue/fusion/operations.hpp index a483b1ba94..a01288778c 100644 --- a/include/cutlass/epilogue/fusion/operations.hpp +++ b/include/cutlass/epilogue/fusion/operations.hpp @@ -158,6 +158,23 @@ struct LinCombPerRowBiasEltAct static constexpr bool IsEltActSupported = true; }; +// D = alpha * acc + beta * C + per-column bias +template< + class ElementOutput_, + class ElementCompute_, + class ElementBias_ = ElementOutput_, + class ElementSource_ = ElementOutput_, + class ElementScalar_ = ElementCompute_, + int AlignmentBias_ = 128 / sizeof_bits_v, + FloatRoundStyle RoundStyle_ = FloatRoundStyle::round_to_nearest +> +struct LinCombPerColBias + : LinearCombination { + using ElementBias = ElementBias_; + static constexpr int AlignmentBias = AlignmentBias_; + static constexpr bool IsPerColBiasSupported = true; +}; + // D = activation(alpha * acc + beta * C + per-row bias) // aux = alpha * acc + beta * C + per-row bias template< diff --git a/include/cutlass/epilogue/fusion/sm90_callbacks_tma_warpspecialized.hpp b/include/cutlass/epilogue/fusion/sm90_callbacks_tma_warpspecialized.hpp index 1de0a28e0f..06f315779b 100644 --- a/include/cutlass/epilogue/fusion/sm90_callbacks_tma_warpspecialized.hpp +++ b/include/cutlass/epilogue/fusion/sm90_callbacks_tma_warpspecialized.hpp @@ -333,6 +333,90 @@ struct FusionCallbacks< ///////////////////////////////////////////////////////////////////////////////////////////////// +// D = alpha * acc + beta * C + per-column bias +template< + int StagesC, + class CtaTileShapeMNK, + class EpilogueTile, + class ElementOutput, + class ElementCompute, + class ElementBias = ElementOutput, + class ElementSource = ElementOutput, + class ElementScalar = ElementCompute, + int AlignmentBias 
= 128 / sizeof_bits_v, + FloatRoundStyle RoundStyle = FloatRoundStyle::round_to_nearest +> +using Sm90LinCombPerColBias = + Sm90EVT, // beta * C + (alpha * acc + bias) + Sm90ScalarBroadcast, // beta + Sm90SrcFetch, // C + Sm90EVT, // alpha * acc + bias + Sm90ScalarBroadcast, // alpha + Sm90AccFetch, // acc + Sm90RowBroadcast<0, CtaTileShapeMNK, ElementBias, Stride<_0,_1,int>, AlignmentBias> // bias + > + >; + +template < + int StagesC, + int StagesD, + int FragmentSize, + bool ReuseSmemC, + bool DelayTmaStore, + class ElementOutput, + class ElementCompute, + class ElementBias, + class ElementSource, + class ElementScalar, + int AlignmentBias, + FloatRoundStyle RoundStyle, + class CtaTileShapeMNK, + class EpilogueTile +> +struct FusionCallbacks< + epilogue::Sm90TmaWarpSpecialized, + fusion::LinCombPerColBias, + CtaTileShapeMNK, + EpilogueTile +> : Sm90LinCombPerColBias< + StagesC, CtaTileShapeMNK, EpilogueTile, ElementOutput, ElementCompute, ElementBias, ElementSource, ElementScalar, AlignmentBias, RoundStyle> { + using Impl = Sm90LinCombPerColBias< + StagesC, CtaTileShapeMNK, EpilogueTile, ElementOutput, ElementCompute, ElementBias, ElementSource, ElementScalar, AlignmentBias, RoundStyle>; + using Operation = fusion::LinCombPerColBias< + ElementOutput, ElementCompute, ElementBias, ElementSource, ElementScalar, AlignmentBias, RoundStyle>; + + struct Arguments { + ElementScalar alpha = ElementScalar(1); + ElementScalar beta = ElementScalar(0); + ElementScalar const* alpha_ptr = nullptr; + ElementScalar const* beta_ptr = nullptr; + + using StrideBias = Stride<_0,_1,int>; + ElementBias const* bias_ptr = nullptr; + StrideBias dBias = {}; + + operator typename Impl::Arguments() const { + return + { // ternary op : beta * C + (alpha * acc + bias) + {{beta}, {beta_ptr}}, // leaf args : beta + {}, // leaf args : C + { // ternary op : alpha * acc + bias + {{alpha}, {alpha_ptr}}, // leaf args : alpha + {}, // leaf args : acc + {bias_ptr, ElementBias(0), dBias}, // leaf args : bias + {} // ternary args : multiply_add + }, // end ternary op + {} // ternary args : multiply_add + }; // end ternary op + } + }; + + // Ctor inheritance + using Impl::Impl; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + // D = activation(alpha * acc + beta * C + per-row bias) template< class CtaTileShapeMNK, From 06e337758dd2c01b06930c543fcb0dfb7781cb93 Mon Sep 17 00:00:00 2001 From: Saagar Jha Date: Thu, 5 Sep 2024 14:14:15 -0700 Subject: [PATCH 15/53] Remove extraneous comma in declaration (#1776) --- .../epilogue/fusion/sm90_callbacks_tma_warpspecialized.hpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/include/cutlass/epilogue/fusion/sm90_callbacks_tma_warpspecialized.hpp b/include/cutlass/epilogue/fusion/sm90_callbacks_tma_warpspecialized.hpp index 06f315779b..3f43f60d6e 100644 --- a/include/cutlass/epilogue/fusion/sm90_callbacks_tma_warpspecialized.hpp +++ b/include/cutlass/epilogue/fusion/sm90_callbacks_tma_warpspecialized.hpp @@ -993,7 +993,7 @@ using Sm90ScaledLinCombPerRowBiasEltActAmaxAuxNotFp8 = Sm90EVT, // activation(Z) Sm90EVT, // Aux = Z // Z = scale_a * scale_b * alpha * acc + scale_c * beta * C + per-row bias - Sm90ScaledLinCombPerRowBias, + Sm90ScaledLinCombPerRowBias > > >, From 82f5075946e2569589439d500733b700a3141374 Mon Sep 17 00:00:00 2001 From: Gabriel Wu Date: Fri, 6 Sep 2024 11:24:10 +0800 Subject: [PATCH 16/53] set_slice3x3 -> set_slice_3x3 (#1784) --- include/cutlass/matrix.h | 8 ++++---- 1 file changed, 4 
insertions(+), 4 deletions(-) diff --git a/include/cutlass/matrix.h b/include/cutlass/matrix.h index ab32597e39..5d8ccb3c1c 100644 --- a/include/cutlass/matrix.h +++ b/include/cutlass/matrix.h @@ -7825,7 +7825,7 @@ struct Matrix { Matrix m; - m.set_slice3x3({ + m.set_slice_3x3({ c + x * x * one_minus_cos, x * y * one_minus_cos - z * s, x * z * one_minus_cos + y * s, y * x * one_minus_cos * z * s, c + y * y * one_minus_cos, y * z * one_minus_cos - x * s, z * x * one_minus_cos - y * s, z * y * one_minus_cos + x * s, c + z * z * one_minus_cos @@ -7845,7 +7845,7 @@ struct Matrix { Matrix m = Matrix::identity(); - m.set_slice3x3({ + m.set_slice_3x3({ Element(1) - Element(2) * a * a, Element(-2) * a * b, Element(-2) * a * c, Element(-2) * a * b, Element(1) - Element(2) * b * b, Element(-2) * b * c, Element(-2) * a * c, Element(-2) * b * c, Element(1) - Element(2) * c * c @@ -14005,7 +14005,7 @@ struct Matrix { Matrix m; - m.set_slice3x3({ + m.set_slice_3x3({ c + x * x * one_minus_cos, x * y * one_minus_cos - z * s, x * z * one_minus_cos + y * s, y * x * one_minus_cos * z * s, c + y * y * one_minus_cos, y * z * one_minus_cos - x * s, z * x * one_minus_cos - y * s, z * y * one_minus_cos + x * s, c + z * z * one_minus_cos @@ -14025,7 +14025,7 @@ struct Matrix { Matrix m = Matrix::identity(); - m.set_slice3x3({ + m.set_slice_3x3({ Element(1) - Element(2) * a * a, Element(-2) * a * b, Element(-2) * a * c, Element(-2) * a * b, Element(1) - Element(2) * b * b, Element(-2) * b * c, Element(-2) * a * c, Element(-2) * b * c, Element(1) - Element(2) * c * c From 323c8170bffdd11d774437b450e42d842e203517 Mon Sep 17 00:00:00 2001 From: Tri Dao Date: Thu, 5 Sep 2024 20:25:03 -0700 Subject: [PATCH 17/53] Support ComputeFn where output type differs from input type (#1771) This is useful for e.g. function taking in 2 float inputs and turn them to complex --- .../sm90_visitor_compute_tma_warpspecialized.hpp | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/include/cutlass/epilogue/fusion/sm90_visitor_compute_tma_warpspecialized.hpp b/include/cutlass/epilogue/fusion/sm90_visitor_compute_tma_warpspecialized.hpp index 0b12badc7d..8f5ceb5489 100644 --- a/include/cutlass/epilogue/fusion/sm90_visitor_compute_tma_warpspecialized.hpp +++ b/include/cutlass/epilogue/fusion/sm90_visitor_compute_tma_warpspecialized.hpp @@ -181,14 +181,20 @@ struct Sm90Compute { }, [&] (auto&&... 
cvt_frg_inputs) { using ComputeOutput = ComputeFn>; - using ConvertOutput = NumericArrayConverter; ComputeOutput compute_output{}; - ConvertOutput convert_output{}; if constexpr (cute::is_same_v) { + using ElementComputeOutput = + typename cute::remove_cvref_t::Element; + using ConvertOutput = NumericArrayConverter; + ConvertOutput convert_output{}; return convert_output(compute_output(cvt_frg_inputs...)); } else { + using ElementComputeOutput = + typename cute::remove_cvref_t::Element; + using ConvertOutput = NumericArrayConverter; + ConvertOutput convert_output{}; return convert_output(compute_output(cvt_frg_inputs..., params)); } } From 21d0534167d71c806af7f88d70ba024cb85f34c3 Mon Sep 17 00:00:00 2001 From: Sean Xiaowen Zhang Date: Mon, 9 Sep 2024 11:05:27 -0700 Subject: [PATCH 18/53] fix assertion (#1790) --- include/cute/arch/copy_sm80.hpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/include/cute/arch/copy_sm80.hpp b/include/cute/arch/copy_sm80.hpp index 43e3d0d728..e04181bfe9 100644 --- a/include/cute/arch/copy_sm80.hpp +++ b/include/cute/arch/copy_sm80.hpp @@ -77,7 +77,7 @@ struct SM80_CP_ASYNC_CACHEGLOBAL using DRegisters = TD[1]; static_assert(sizeof(TS) == sizeof(TD), "cp.async requires sizeof(src_value_type) == sizeof(dst_value_type)"); - static_assert(sizeof(TS) == 4 || sizeof(TS) == 8 || sizeof(TS) == 16, "cp.async sizeof(TS) is not supported"); + static_assert(sizeof(TS) == 16, "cp.async sizeof(TS) is not supported"); CUTE_HOST_DEVICE static void copy(TS const& gmem_src, @@ -134,7 +134,7 @@ struct SM80_CP_ASYNC_CACHEGLOBAL_ZFILL using DRegisters = TD[1]; static_assert(sizeof(TS) == sizeof(TD), "cp.async requires sizeof(src_value_type) == sizeof(dst_value_type)"); - static_assert(sizeof(TS) == 4 || sizeof(TS) == 8 || sizeof(TS) == 16, "cp.async sizeof(TS) is not supported"); + static_assert(sizeof(TS) == 16, "cp.async sizeof(TS) is not supported"); CUTE_HOST_DEVICE static void copy(TS const& gmem_src, From dbdae514e03f83968f8b7dd4fb064071b9bfbdd1 Mon Sep 17 00:00:00 2001 From: Junkai-Wu Date: Wed, 11 Sep 2024 12:07:31 +0800 Subject: [PATCH 19/53] Support for TMA Epilogue for Group Gemm and add pingpong ptr array & Group Gemm (#1795) --- .../56_hopper_ptr_array_batched_gemm.cu | 120 ++- .../57_hopper_grouped_gemm.cu | 105 +- .../57_hopper_grouped_gemm/CMakeLists.txt | 4 +- include/cute/util/type_traits.hpp | 14 + .../collective/builders/sm90_builder.inl | 33 +- .../cutlass/epilogue/collective/detail.hpp | 92 +- ...m90_epilogue_array_tma_warpspecialized.hpp | 360 +++++-- include/cutlass/epilogue/dispatch_policy.hpp | 20 +- .../cutlass/epilogue/fusion/operations.hpp | 1 + .../sm90_callbacks_tma_warpspecialized.hpp | 83 ++ .../sm90_visitor_load_tma_warpspecialized.hpp | 167 ++- .../collective/builders/sm90_gmma_builder.inl | 60 +- ..._mma_array_tma_gmma_ss_warpspecialized.hpp | 62 +- include/cutlass/gemm/dispatch_policy.hpp | 7 +- .../cutlass/gemm/kernel/gemm_universal.hpp | 1 + ..._array_tma_warpspecialized_cooperative.hpp | 175 ++-- ...emm_array_tma_warpspecialized_pingpong.hpp | 949 ++++++++++++++++++ test/unit/gemm/device/CMakeLists.txt | 2 + .../gemm/device/gemm_testbed_3x_ptr_array.hpp | 14 +- ...mm_f16_f16_f16_tensor_op_f32_group_gemm.cu | 66 +- ...6_f16_tensor_op_f32_group_gemm_pingpong.cu | 184 ++++ ...emm_f16_f16_f16_tensor_op_f32_ptr_array.cu | 5 +- ...16_f16_tensor_op_f32_ptr_array_pingpong.cu | 182 ++++ 23 files changed, 2359 insertions(+), 347 deletions(-) create mode 100644 
include/cutlass/gemm/kernel/sm90_gemm_array_tma_warpspecialized_pingpong.hpp create mode 100644 test/unit/gemm/device/sm90_gemm_f16_f16_f16_tensor_op_f32_group_gemm_pingpong.cu create mode 100644 test/unit/gemm/device/sm90_gemm_f16_f16_f16_tensor_op_f32_ptr_array_pingpong.cu diff --git a/examples/56_hopper_ptr_array_batched_gemm/56_hopper_ptr_array_batched_gemm.cu b/examples/56_hopper_ptr_array_batched_gemm/56_hopper_ptr_array_batched_gemm.cu index 7a191ce2d8..5181678ca7 100644 --- a/examples/56_hopper_ptr_array_batched_gemm/56_hopper_ptr_array_batched_gemm.cu +++ b/examples/56_hopper_ptr_array_batched_gemm/56_hopper_ptr_array_batched_gemm.cu @@ -95,40 +95,66 @@ constexpr int AlignmentC = 128 / cutlass::sizeof_bits::value; // M using ElementAccumulator = float; // Element type for internal accumulation using ArchTag = cutlass::arch::Sm90; // Tag indicating the minimum SM that supports the intended feature using OperatorClass = cutlass::arch::OpClassTensorOp; // Operator class tag -using TileShape = Shape<_256,_128,_64>; // Threadblock-level tile size -using ClusterShape = Shape<_1,_2,_1>; // Shape of the threadblocks in a cluster using StageCountType = cutlass::gemm::collective::StageCountAuto; // Stage count maximized based on the tile size -using KernelSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecializedCooperative; // Kernel to launch -using EpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecializedCooperative; // Epilogue to launch - -using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< - cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, - TileShape, ClusterShape, - cutlass::epilogue::collective::EpilogueTileAuto, - ElementAccumulator, ElementAccumulator, - ElementC, LayoutC, AlignmentC, - ElementC, LayoutC, AlignmentC, - EpilogueSchedule - >::CollectiveOp; - -using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< - ArchTag, OperatorClass, - ElementA, LayoutA, AlignmentA, - ElementB, LayoutB, AlignmentB, - ElementAccumulator, - TileShape, ClusterShape, - cutlass::gemm::collective::StageCountAutoCarveout< - static_cast(sizeof(typename CollectiveEpilogue::SharedStorage))>, - KernelSchedule - >::CollectiveOp; - -using GemmKernel = cutlass::gemm::kernel::GemmUniversal< - cutlass::gemm::ArrayProblemShape>, - CollectiveMainloop, - CollectiveEpilogue ->; - -using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + +// Different configs for pingpong/cooperative +struct CooperativeConfig { + using KernelSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecializedCooperative; + using EpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecializedCooperative; + using TileShape = Shape<_256,_128,_64>; + using ClusterShape = Shape<_1,_2,_1>; +}; + +struct PingpongConfig { + using KernelSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecializedPingpong; + using EpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecializedPingpong; + using TileShape = Shape<_64,_128,_64>; + using ClusterShape = Shape<_1,_1,_1>; +}; + +template +struct GemmGivenSchedule { + using TileShape = typename ScheduleConfig::TileShape; // Threadblock-level tile size + using ClusterShape = typename ScheduleConfig::ClusterShape; // Shape of the threadblocks in a cluster + using KernelSchedule = typename ScheduleConfig::KernelSchedule; // Kernel to launch + using EpilogueSchedule = typename ScheduleConfig::EpilogueSchedule; // Epilogue to launch + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + 
cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + TileShape, ClusterShape, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAccumulator, ElementAccumulator, + ElementC, LayoutC, AlignmentC, + ElementC, LayoutC, AlignmentC, + EpilogueSchedule + >::CollectiveOp; + + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + ArchTag, OperatorClass, + ElementA, LayoutA, AlignmentA, + ElementB, LayoutB, AlignmentB, + ElementAccumulator, + TileShape, ClusterShape, + cutlass::gemm::collective::StageCountAutoCarveout< + static_cast(sizeof(typename CollectiveEpilogue::SharedStorage))>, + KernelSchedule + >::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + cutlass::gemm::ArrayProblemShape>, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; +}; + +using GemmKernel = GemmGivenSchedule::GemmKernel; +using Gemm = GemmGivenSchedule::Gemm; + +using GemmKernelPingpong = GemmGivenSchedule::GemmKernel; +using GemmPingpong = GemmGivenSchedule::Gemm; + // Reference device GEMM implementation type using DeviceGemmReference = cutlass::reference::device::Gemm< @@ -261,14 +287,14 @@ bool initialize_block( int bits_input = cutlass::sizeof_bits::value; if (bits_input == 1) { - scope_max = 2; - scope_min = 0; + scope_max = static_cast(2); + scope_min = static_cast(0); } else if (bits_input <= 8) { - scope_max = 2; - scope_min = -2; + scope_max = static_cast(2); + scope_min = static_cast(-2); } else { - scope_max = 8; - scope_min = -8; + scope_max = static_cast(8); + scope_min = static_cast(-8); } cutlass::reference::device::BlockFillRandomUniform( @@ -351,7 +377,8 @@ void initialize(const Options &options) { } /// Populates a Gemm::Arguments structure from the given commandline options -typename Gemm::Arguments args_from_options(const Options &options) +template +typename GemmT::Arguments args_from_options(const Options &options) { cutlass::KernelHardwareInfo hw_info; // Change device_id to another value if you are running on a machine with multiple GPUs and wish @@ -359,7 +386,7 @@ typename Gemm::Arguments args_from_options(const Options &options) hw_info.device_id = 0; hw_info.sm_count = cutlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id); - typename Gemm::Arguments arguments{ + typename GemmT::Arguments arguments{ cutlass::gemm::GemmUniversalMode::kArray, {{options.m, options.n, options.k, options.l}}, {ptr_A.get(), stride_A, ptr_B.get(), stride_B}, @@ -405,20 +432,20 @@ bool verify(const Options &options) { } /// Execute a given example GEMM computation -template +template int run(Options &options) { allocate(options); initialize(options); // Instantiate CUTLASS kernel depending on templates - Gemm gemm; + GemmT gemm; // Create a structure of gemm kernel arguments suitable for invoking an instance of Gemm - auto arguments = args_from_options(options); + auto arguments = args_from_options(options); // Using the arguments, query for extra workspace required for matrix multiplication computation - size_t workspace_size = Gemm::get_workspace_size(arguments); + size_t workspace_size = GemmT::get_workspace_size(arguments); // Allocate workspace memory cutlass::device_memory::allocation workspace(workspace_size); @@ -510,7 +537,10 @@ int main(int argc, char const **args) { // #if defined(CUTLASS_ARCH_MMA_MODIFIABLE_TMA_SM90_SUPPORTED) + std::cout << "\n*** Cooperative schedule ***" << std::endl; run(options); + std::cout << "\n*** Pingpong schedule ***" << 
std::endl; + run(options); #endif return 0; diff --git a/examples/57_hopper_grouped_gemm/57_hopper_grouped_gemm.cu b/examples/57_hopper_grouped_gemm/57_hopper_grouped_gemm.cu index f94679568a..a26d904dcc 100644 --- a/examples/57_hopper_grouped_gemm/57_hopper_grouped_gemm.cu +++ b/examples/57_hopper_grouped_gemm/57_hopper_grouped_gemm.cu @@ -117,20 +117,39 @@ constexpr int AlignmentC = 128 / cutlass::sizeof_bits::value; // A using ElementAccumulator = float; // Element type for internal accumulation using ArchTag = cutlass::arch::Sm90; // Tag indicating the minimum SM that supports the intended feature using OperatorClass = cutlass::arch::OpClassTensorOp; // Operator class tag -using TileShape = Shape<_256,_128,_128>; // Threadblock-level tile size -using ClusterShape = Shape<_2,_2,_1>; // Shape of the threadblocks in a cluster using StageCountType = cutlass::gemm::collective::StageCountAuto; // Stage count maximized based on the tile size -using KernelSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecializedCooperativeFP8FastAccum; // Kernel to launch -using EpilogueSchedule = cutlass::epilogue::PtrArrayNoSmemWarpSpecialized; // Epilogue to launch -using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< +// Different configs for pingpong/cooperative +struct CooperativeConfig { + using KernelSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecializedCooperativeFP8FastAccum; + using EpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecializedCooperative; + using TileShape = Shape<_256,_128,_128>; + using ClusterShape = Shape<_2,_2,_1>; +}; + +struct PingpongConfig { + using KernelSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecializedPingpongFP8FastAccum; + using EpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecializedPingpong; + using TileShape = Shape<_128,_128,_128>; + using ClusterShape = Shape<_2,_1,_1>; +}; + +template +struct GemmGivenSchedule { + using TileShape = typename ScheduleConfig::TileShape; // Threadblock-level tile size + using ClusterShape = typename ScheduleConfig::ClusterShape; // Shape of the threadblocks in a cluster + using KernelSchedule = typename ScheduleConfig::KernelSchedule; // Kernel to launch + using EpilogueSchedule = typename ScheduleConfig::EpilogueSchedule; // Epilogue to launch + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, TileShape, ClusterShape, cutlass::epilogue::collective::EpilogueTileAuto, ElementAccumulator, ElementAccumulator, ElementC, LayoutC *, AlignmentC, ElementC, LayoutC *, AlignmentC, - EpilogueSchedule + EpilogueSchedule, + cutlass::epilogue::fusion::LinearCombination >::CollectiveOp; using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< @@ -144,13 +163,20 @@ using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder KernelSchedule >::CollectiveOp; -using GemmKernel = cutlass::gemm::kernel::GemmUniversal< - ProblemShape, - CollectiveMainloop, - CollectiveEpilogue ->; + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + ProblemShape, + CollectiveMainloop, + CollectiveEpilogue + >; -using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; +}; + +using GemmKernel = GemmGivenSchedule::GemmKernel; +using Gemm = GemmGivenSchedule::Gemm; + +using GemmKernelPingpong = GemmGivenSchedule::GemmKernel; +using GemmPingpong = GemmGivenSchedule::Gemm; // Reference device GEMM implementation type 
using DeviceGemmReference = cutlass::reference::device::Gemm< @@ -271,10 +297,10 @@ struct Options { int n = cmd_line_n; int k = cmd_line_k; if (m < 1) { - m = ((rand() % 512) + 1); + m = alignment * ((rand() % 64) + 1); } if (n < 1) { - n = ((rand() % 512) + 1); + n = alignment * ((rand() % 64) + 1); } if (k < 1) { k = alignment * ((rand() % 64) + 1); @@ -521,7 +547,8 @@ void initialize(const Options &options) { } /// Populates a Gemm::Arguments structure from the given commandline options -typename Gemm::Arguments args_from_options(const Options &options, bool host_problem_shapes_available = true) +template +typename GemmT::Arguments args_from_options(const Options &options, bool host_problem_shapes_available = true) { cutlass::KernelHardwareInfo hw_info; // Change device_id to another value if you are running on a machine with multiple GPUs and wish @@ -529,33 +556,49 @@ typename Gemm::Arguments args_from_options(const Options &options, bool host_pro hw_info.device_id = 0; hw_info.sm_count = cutlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id); - typename Gemm::EpilogueOutputOp::Params params; + typename GemmT::Arguments arguments; + decltype(arguments.epilogue.thread) fusion_args; + if (options.alpha != FLT_MAX && options.beta != FLT_MAX) { // If both alpha/beta are provided (via cmd line args) and are scalar, i.e., same alpha/beta applies to all batches. - params = typename Gemm::EpilogueOutputOp::Params( - ElementAccumulator(options.alpha), ElementAccumulator(options.beta)); + fusion_args.alpha = options.alpha; + fusion_args.beta = options.beta; + fusion_args.alpha_ptr = nullptr; + fusion_args.beta_ptr = nullptr; + fusion_args.alpha_ptr_array = nullptr; + fusion_args.beta_ptr_array = nullptr; + // Single alpha and beta for all groups + fusion_args.dAlpha = {cute::_0{}, cute::_0{}, 0}; + fusion_args.dBeta = {cute::_0{}, cute::_0{}, 0}; } else { // If pointers to alpha/beta are provided, i.e., alpha/beta can differ between batches/groups. 
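The dAlpha/dBeta strides assigned in this function control how the epilogue resolves scalars per group: a group-mode stride of 1 picks the g-th entry of alpha_ptr_array/beta_ptr_array for group g, while a stride of 0 reuses a single value for every group. A hypothetical sketch of that selection (names are illustrative, not the kernel's internal API):

// Illustrative only: models how a group-mode stride of 0 or 1 resolves the
// scalar applied to group 'group_idx'.
template <class ElementScalar>
ElementScalar select_scalar_for_group(ElementScalar const* scalars,  // one entry per group when stride == 1
                                      int group_stride,              // 0 = broadcast entry 0 to all groups
                                      int group_idx) {
  return scalars[group_idx * group_stride];
}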
- params = typename Gemm::EpilogueOutputOp::Params(alpha_device.get(), beta_device.get()); + fusion_args.alpha = 0; + fusion_args.beta = 0; + fusion_args.alpha_ptr = nullptr; + fusion_args.beta_ptr = nullptr; + fusion_args.alpha_ptr_array = alpha_device.get(); + fusion_args.beta_ptr_array = beta_device.get(); + // One alpha and beta per each group + fusion_args.dAlpha = {cute::_0{}, cute::_0{}, 1}; + fusion_args.dBeta = {cute::_0{}, cute::_0{}, 1}; } - typename Gemm::Arguments arguments; if (host_problem_shapes_available) { - arguments = typename Gemm::Arguments { + arguments = typename GemmT::Arguments { cutlass::gemm::GemmUniversalMode::kGrouped, {options.groups, problem_sizes.get(), options.problem_sizes_host.data()}, {ptr_A.get(), stride_A.get(), ptr_B.get(), stride_B.get()}, - {params, ptr_C.get(), stride_C.get(), ptr_D.get(), stride_D.get()}, + {fusion_args, ptr_C.get(), stride_C.get(), ptr_D.get(), stride_D.get()}, hw_info }; } else { - arguments = typename Gemm::Arguments { + arguments = typename GemmT::Arguments { cutlass::gemm::GemmUniversalMode::kGrouped, {options.groups, problem_sizes.get(), nullptr}, {ptr_A.get(), stride_A.get(), ptr_B.get(), stride_B.get()}, - {params, ptr_C.get(), stride_C.get(), ptr_D.get(), stride_D.get()}, + {fusion_args, ptr_C.get(), stride_C.get(), ptr_D.get(), stride_D.get()}, hw_info }; } @@ -605,20 +648,20 @@ bool verify(const Options &options) { } /// Execute a given example GEMM computation -template +template int run(Options &options, bool host_problem_shapes_available = true) { allocate(options); initialize(options); // Instantiate CUTLASS kernel depending on templates - Gemm gemm; + GemmT gemm; // Create a structure of gemm kernel arguments suitable for invoking an instance of Gemm - auto arguments = args_from_options(options, host_problem_shapes_available); + auto arguments = args_from_options(options, host_problem_shapes_available); // Using the arguments, query for extra workspace required for matrix multiplication computation - size_t workspace_size = Gemm::get_workspace_size(arguments); + size_t workspace_size = GemmT::get_workspace_size(arguments); // Allocate workspace memory cutlass::device_memory::allocation workspace(workspace_size); @@ -713,8 +756,14 @@ int main(int argc, char const **args) { // #if defined(CUTLASS_ARCH_MMA_MODIFIABLE_TMA_SM90_SUPPORTED) + std::cout << "\n*** Cooperative schedule ***" << std::endl; run(options); + std::cout << "\n*** Cooperative schedule (host problem shapes unavailable) ***" << std::endl; run(options, false /*host_problem_shapes_available*/); + std::cout << "\n*** Pingpong schedule ***" << std::endl; + run(options); + std::cout << "\n*** Pingpong schedule (host problem shapes unavailable) ***" << std::endl; + run(options, false /*host_problem_shapes_available*/); #endif return 0; diff --git a/examples/57_hopper_grouped_gemm/CMakeLists.txt b/examples/57_hopper_grouped_gemm/CMakeLists.txt index 2c3ff3a496..1dadbfa813 100644 --- a/examples/57_hopper_grouped_gemm/CMakeLists.txt +++ b/examples/57_hopper_grouped_gemm/CMakeLists.txt @@ -32,10 +32,10 @@ set(TEST_RANDOM --iterations=0) # Random problem sizes set(TEST_RANDOM_LARGE_GROUP --groups=500 --iterations=0) # Random problem sizes -set(TEST_EPILOGUE --alpha=0.5 --beta=0.7 --iterations=0) # Random problem sizes +set(TEST_EPILOGUE --alpha=0.5 --beta=0.5 --iterations=0) # Random problem sizes set(TEST_EPILOGUE_LARGE_GROUP --alpha=1.5 --beta=2.0 --groups=500 --iterations=0) # Random problem sizes -set(TEST_EPILOGUE_OP --beta=0.7 --iterations=1) # Random 
problem sizes +set(TEST_EPILOGUE_OP --beta=0.5 --iterations=1) # Random problem sizes set(TEST_EPILOGUE_OP_LARGE_GROUP --alpha=1.5 --iterations=1) # Random problem sizes set(TEST_FIXED --m=2048 --n=5120 --k=8192 --groups=50 --iterations=0) # Fixed problem sizes diff --git a/include/cute/util/type_traits.hpp b/include/cute/util/type_traits.hpp index f12cdb594f..f0eb55116d 100644 --- a/include/cute/util/type_traits.hpp +++ b/include/cute/util/type_traits.hpp @@ -274,4 +274,18 @@ struct conditional_template { using type = False; }; +// +// is_any_of +// + +/// Member `value` is true if and only if T is same as (is_same_v) at least one of the types in Us +template +struct is_any_of { + constexpr static bool value = (... || CUTE_STL_NAMESPACE::is_same_v); +}; + +/// Is true if and only if T is same as (is_same_v) at least one of the types in Us +template +inline constexpr bool is_any_of_v = is_any_of::value; + } // end namespace cute diff --git a/include/cutlass/epilogue/collective/builders/sm90_builder.inl b/include/cutlass/epilogue/collective/builders/sm90_builder.inl index 2ca62c9794..90a600028c 100644 --- a/include/cutlass/epilogue/collective/builders/sm90_builder.inl +++ b/include/cutlass/epilogue/collective/builders/sm90_builder.inl @@ -71,14 +71,18 @@ sm90_get_tma_dispatch_policy() { // 8b residuals load fast and consume little smem, so the perf cost of waiting on stores to finish outweighs the cost of extra allocation constexpr bool ReuseSmem = (sizeof_bits_v == sizeof_bits_v) && (sizeof_bits_v > 8); // TMA store delay performs worse with residual loads and compilicates tensormap updates for Ptr-Array GEMMs - constexpr bool DelayTmaStore = is_void_v && !detail::sm90_is_tma_ptr_array_v; + constexpr bool DelayTmaStore = is_void_v && !detail::sm90_is_ptr_array_tma_v; constexpr int StagesD = cute::min(EpiTiles, 2); constexpr int StagesC = ReuseSmem ? cute::max(cute::min(EpiTiles, 4), StagesD+1) : cute::min(EpiTiles, 4); - return cute::conditional_t, - Sm90PtrArrayTmaWarpSpecialized, - Sm90TmaWarpSpecialized>{}; + if constexpr (detail::sm90_is_ptr_array_tma_v) { + return Sm90PtrArrayTmaWarpSpecialized{}; + } + else { + return Sm90TmaWarpSpecialized{}; + } } // Returns the smem layout atom to be used for C or D matrix @@ -255,6 +259,9 @@ struct Sm90TmaBuilderImpl { using GmemStrideTypeC = cutlass::detail::TagToStrideC_t; using GmemStrideTypeD = cutlass::detail::TagToStrideC_t; + using UnderlyingGmemStrideTypeC = cute::remove_pointer_t; + using UnderlyingGmemStrideTypeD = cute::remove_pointer_t; + using CopyOpS2G = cute::conditional_t, SM90_TMA_STORE_IM2COL, SM90_TMA_STORE @@ -267,17 +274,11 @@ struct Sm90TmaBuilderImpl { // Get the smallest tiled copy we can use to retile the accumulators using CopyAtomC = Copy_Atom; - using FusionDispatchPolicy = Sm90TmaWarpSpecialized; - // TMA builder allows for passing callbacks directly, which is either a fusion::FusionCallbacks // instance or a direct visitor implementation, e.g. 
fusion::Sm90LinearCombination using FusionCallbacks = typename CallbacksBuilder< - FusionDispatchPolicy, + DispatchPolicy, FusionOpOrCallbacks, TileShape_MNK, EpilogueTile_MN, @@ -294,11 +295,11 @@ struct Sm90TmaBuilderImpl { GmemStrideTypeD, FusionCallbacks, CopyOpG2S, - decltype(detail::sm90_get_epilogue_smem_swizzle_layout_atom()), - decltype(detail::sm90_get_smem_load_op_for_source()), + decltype(detail::sm90_get_epilogue_smem_swizzle_layout_atom()), + decltype(detail::sm90_get_smem_load_op_for_source()), CopyOpS2G, - decltype(detail::sm90_get_epilogue_smem_swizzle_layout_atom()), - decltype(detail::sm90_get_smem_store_op_for_accumulator()), + decltype(detail::sm90_get_epilogue_smem_swizzle_layout_atom()), + decltype(detail::sm90_get_smem_store_op_for_accumulator()), CopyAtomC >; }; @@ -483,7 +484,7 @@ struct CollectiveBuilder< FusionOperation, cute::enable_if_t || cute::is_same_v || - cute::is_same_v >> { + detail::sm90_is_ptr_array_tma_v>> { private: using ElementD = cute::conditional_t, fusion::get_element_aux_t, ElementD_>; diff --git a/include/cutlass/epilogue/collective/detail.hpp b/include/cutlass/epilogue/collective/detail.hpp index a6e5e2f4d6..b96b13fecc 100644 --- a/include/cutlass/epilogue/collective/detail.hpp +++ b/include/cutlass/epilogue/collective/detail.hpp @@ -71,6 +71,62 @@ is_im2col() { || cute::is_same_v>; } +template +struct sm90_is_ptr_array_tma : cute::false_type {}; + +template<> +struct sm90_is_ptr_array_tma : cute::true_type {}; + +template<> +struct sm90_is_ptr_array_tma : cute::true_type {}; + +template<> +struct sm90_is_ptr_array_tma : cute::true_type {}; + +template +static constexpr bool sm90_is_ptr_array_tma_v = sm90_is_ptr_array_tma::value; + +template +struct sm90_is_ptr_array_tma_cooperative : cute::false_type {}; + +template<> +struct sm90_is_ptr_array_tma_cooperative : cute::true_type {}; + +template +static constexpr bool sm90_is_ptr_array_tma_cooperative_v = sm90_is_ptr_array_tma_cooperative::value; + +template +struct sm90_is_ptr_array_tma_pingpong : cute::false_type {}; + +template<> +struct sm90_is_ptr_array_tma_pingpong : cute::true_type {}; + +template +static constexpr bool sm90_is_ptr_array_tma_pingpong_v = sm90_is_ptr_array_tma_pingpong::value; + +template +struct sm90_is_ptr_array_tma_dispatch_policy : cute::false_type {}; + +template< + int StagesC, + int StagesD, + int FragmentSize, + bool ReuseSmemC, + bool DelayTmaStore, + int NumEpilogueWarpGroups +> +struct sm90_is_ptr_array_tma_dispatch_policy< + Sm90PtrArrayTmaWarpSpecialized> + : cute::true_type {}; + +template +static constexpr bool sm90_is_ptr_array_tma_dispatch_policy_v = sm90_is_ptr_array_tma_dispatch_policy::value; + using cutlass::atomic_maximum; template @@ -79,14 +135,11 @@ static constexpr int elements_per_access_v = cutlass::sizeof_bits::val template static constexpr bool sm90_is_cooperative_v = cute::is_base_of_v || - cute::is_base_of_v; - -template -static constexpr bool sm90_is_tma_ptr_array_v = - cute::is_base_of_v; + sm90_is_ptr_array_tma_cooperative_v; template static constexpr bool sm90_is_warp_specialized_v = + (!sm90_is_ptr_array_tma_cooperative_v && sm90_is_ptr_array_tma_v) || cute::is_base_of_v; template @@ -199,7 +252,11 @@ class Sm90TmaWarpSpecializedAdapter : public EpilogueOp { } CUTLASS_DEVICE auto - load_init([[maybe_unused]] typename EpilogueOp::Params const& params, [[maybe_unused]] int32_t const sm_count, [[maybe_unused]] int32_t const sm_idx) const { + load_init( + [[maybe_unused]] typename EpilogueOp::Params const& params, + [[maybe_unused]] 
TensorMapStorage& shared_tensormaps, + [[maybe_unused]] int32_t sm_count, + [[maybe_unused]] int32_t sm_idx) { return cute::make_tuple(nullptr); } @@ -243,7 +300,7 @@ class Sm90TmaWarpSpecializedAdapter : public EpilogueOp { [[maybe_unused]] TensorStorage& shared_tensors, [[maybe_unused]] TensorMapC const& load_tensormap, [[maybe_unused]] int subtile_idx=-1, - [[maybe_unused]] bool return_prior_state = false) + [[maybe_unused]] bool wait = false) { return load_pipe_producer_state; } @@ -257,8 +314,12 @@ class Sm90TmaWarpSpecializedAdapter : public EpilogueOp { } CUTLASS_DEVICE auto - store_init([[maybe_unused]] typename EpilogueOp::Params const& params, [[maybe_unused]] int32_t const sm_count, - [[maybe_unused]] int32_t const sm_idx) const { + store_init( + [[maybe_unused]] typename EpilogueOp::Params const& params, + [[maybe_unused]] TensorMapStorage& shared_tensormaps, + [[maybe_unused]] int32_t sm_count, + [[maybe_unused]] int32_t sm_idx, + [[maybe_unused]] int32_t warp_group_idx) { return cute::make_tuple(nullptr); } @@ -369,22 +430,25 @@ class Sm90TmaWarpSpecializedAdapter : public EpilogueOp { // Dummy methods to perform different parts of TMA/Tensormap modifications - template + template CUTLASS_DEVICE void tensormaps_perform_update( - [[maybe_unused]] TensorMapStorage& shared_tensormap, + [[maybe_unused]] TensorMapStorage& shared_tensormaps, [[maybe_unused]] typename EpilogueOp::Params const& params, [[maybe_unused]] cute::TmaDescriptor const* tensormap, - [[maybe_unused]] int32_t next_batch) { } + [[maybe_unused]] ProblemShapeMNKL problem_shape, + [[maybe_unused]] int32_t next_batch, + [[maybe_unused]] int32_t warp_group_idx) { } template CUTLASS_DEVICE void tensormaps_cp_fence_release( - [[maybe_unused]] TensorMapStorage& shared_tensormap, + [[maybe_unused]] TensorMapStorage& shared_tensormaps, [[maybe_unused]] cute::TmaDescriptor const* tensormap, - [[maybe_unused]] uint32_t lane_predicate) { } + [[maybe_unused]] uint32_t lane_predicate, + [[maybe_unused]] int32_t warp_group_idx) { } template CUTLASS_DEVICE diff --git a/include/cutlass/epilogue/collective/sm90_epilogue_array_tma_warpspecialized.hpp b/include/cutlass/epilogue/collective/sm90_epilogue_array_tma_warpspecialized.hpp index 87e628879c..87b6786721 100644 --- a/include/cutlass/epilogue/collective/sm90_epilogue_array_tma_warpspecialized.hpp +++ b/include/cutlass/epilogue/collective/sm90_epilogue_array_tma_warpspecialized.hpp @@ -44,9 +44,10 @@ #include "cutlass/detail/collective.hpp" #include "cutlass/detail/layout.hpp" #include "cutlass/trace.h" +#include "cutlass/cuda_host_adapter.hpp" #include "cute/tensor.hpp" -#include "cutlass/cuda_host_adapter.hpp" +#include "cute/atom/copy_traits_sm90_tma.hpp" ///////////////////////////////////////////////////////////////////////////////////////////////// @@ -62,6 +63,7 @@ template < int FragmentSize_, bool ReuseSmemC_, bool DelayTmaStore_, + int NumEpilogueWarpGroups_, class CtaTileMNK_, // (CTA_M,CTA_N,CTA_K) class EpilogueTile_, // (EPI_TILE_M,EPI_TILE_N) class ElementC_, @@ -78,7 +80,13 @@ template < class CopyAtomC_ > class CollectiveEpilogue< - Sm90PtrArrayTmaWarpSpecialized, + Sm90PtrArrayTmaWarpSpecialized, CtaTileMNK_, EpilogueTile_, ElementC_, @@ -98,7 +106,13 @@ class CollectiveEpilogue< // // Type Aliases // - using DispatchPolicy = Sm90PtrArrayTmaWarpSpecialized; + using DispatchPolicy = Sm90PtrArrayTmaWarpSpecialized; using CtaTileMNK = CtaTileMNK_; using EpilogueTile = EpilogueTile_; using FusionCallbacks = FusionCallbacks_; @@ -201,6 +215,8 @@ class 
CollectiveEpilogue< (size(take<0,2>(SmemLayoutC{})) * static_cast(sizeof_bits::value)) / 8; constexpr static bool RequiresTransactionBytes = true; + constexpr static int NumEpilogueWarpGroups = NumEpilogueWarpGroups_; + // TMA pipeline for storing D using StorePipeline = cute::conditional_t, @@ -219,7 +235,7 @@ class CollectiveEpilogue< struct TensorMapStorage : cute::aligned_struct<128> { cute::TmaDescriptor smem_tensormap_C; - cute::TmaDescriptor smem_tensormap_D; + cute::array smem_tensormap_D; } tensormaps; using PipelineStorage = typename LoadPipeline::SharedStorage; @@ -229,6 +245,8 @@ class CollectiveEpilogue< using TensorMapStorage = typename SharedStorage::TensorMapStorage; using PipelineStorage = typename SharedStorage::PipelineStorage; + static constexpr bool IsGroupedGemmKernel = !cute::is_same_v; + // Host side epilogue arguments struct Arguments { typename FusionCallbacks::Arguments thread{}; @@ -261,7 +279,9 @@ class CollectiveEpilogue< TMA_D tma_store_d; cute::TmaDescriptor* tensormaps; ElementC const** ptr_C; + StrideC dC; ElementD** ptr_D; + StrideD dD; uint32_t tma_transaction_bytes = TmaTransactionBytes; }; @@ -275,36 +295,57 @@ class CollectiveEpilogue< ProblemShape const& problem_shape, Arguments const& args, [[maybe_unused]] void* workspace) { - // Optionally append 1s until problem shape is rank-4 in case its is only rank-3 (MNK) - auto problem_shape_MNKL = append<4>(problem_shape.get_host_problem_shape(), 1); - auto [M, N, K, mock_L] = problem_shape_MNKL; - // Manage batches/groups through pointers to input matricies - mock_L = 1; + // These tensor shapes (only applicable for grouped gemm) and pointers are only used to create tensormap/tma desc. + // These will be replaced with correct values before the initial tma load. + auto init_shape = repeat_like(append<4>(typename ProblemShape::UnderlyingProblemShape{}, 1), int32_t(1)); + auto init_M = get<0>(init_shape); + auto init_N = get<1>(init_shape); + auto init_L = get<3>(init_shape); static_assert(!is_im2col_C and !is_im2col_D, "Im2Col not supported on C or D"); + InternalStrideC stride_c; + InternalStrideD stride_d; + if constexpr (IsGroupedGemmKernel) { + // Strides for Grouped Gemm will be replaced prior to the first access regardless. + stride_c = InternalStrideC{}; + stride_d = InternalStrideD{}; + } + else { + // Tensor shapes for Ptr-Array are initialized correctly only here. 
+ auto problem_shape_MNKL = append<4>(problem_shape.get_host_problem_shape(0), 1); + init_M = get<0>(problem_shape_MNKL); + init_N = get<1>(problem_shape_MNKL); + init_L = get<3>(problem_shape_MNKL); + + stride_c = args.dC; + stride_d = args.dD; + } + uint32_t transaction_bytes = TmaTransactionBytes; typename Params::TMA_C tma_load_c = {}; if constexpr (is_source_supported) { ElementC const* ptr_C_first_batch = reinterpret_cast(args.ptr_C); - Tensor tensor_c = make_tensor(ptr_C_first_batch, make_layout(make_shape(M,N,mock_L), append<3>(args.dC, _0{}))); - tma_load_c = make_tma_copy_C_sm90( + Tensor tensor_c = make_tensor(ptr_C_first_batch, make_layout(make_shape(init_M,init_N,init_L), append<3>(stride_c, _0{}))); + tma_load_c = make_tma_copy( CopyOpG2S{}, tensor_c, take<0,2>(SmemLayoutC{}), - EpilogueTile{}); + EpilogueTile{}, + _1{}); } typename Params::TMA_D tma_store_d; if constexpr (is_destination_supported) { ElementD const* ptr_D_first_batch = reinterpret_cast(args.ptr_D); - Tensor tensor_d = make_tensor(ptr_D_first_batch, make_layout(make_shape(M,N,mock_L), append<3>(args.dD, _0{}))); - tma_store_d = make_tma_copy_C_sm90( + Tensor tensor_d = make_tensor(ptr_D_first_batch, make_layout(make_shape(init_M,init_N,init_L), append<3>(stride_d, _0{}))); + tma_store_d = make_tma_copy( CopyOpS2G{}, tensor_d, take<0,2>(SmemLayoutD{}), - EpilogueTile{}); + EpilogueTile{}, + _1{}); } auto fusion_workspace = static_cast(workspace); @@ -318,7 +359,9 @@ class CollectiveEpilogue< tma_store_d, tma_descriptor_workspace, args.ptr_C, + args.dC, args.ptr_D, + args.dD, transaction_bytes, }; } @@ -326,10 +369,11 @@ class CollectiveEpilogue< template static size_t get_workspace_size(ProblemShape const& problem_shape, Arguments const& args, int sm_count) { - constexpr uint32_t NumInputTensors = cute::is_void_v ? 1 : 2; + constexpr uint32_t NumInputTensors = NumEpilogueWarpGroups + (cute::is_void_v ? 
0 : 1); + auto descriptors_shape = cute::make_shape(sm_count, Int{}); constexpr size_t SizeOfCuTensorMap = sizeof(cute::TmaDescriptor); // Allocate gmem space for input tensormaps per each SM, A tensormap copies followed by B tensormap copies - return (NumInputTensors * SizeOfCuTensorMap * sm_count) + FusionCallbacks::get_workspace_size(problem_shape, args.thread); + return (size(descriptors_shape) * SizeOfCuTensorMap) + FusionCallbacks::get_workspace_size(problem_shape, args.thread); } template @@ -342,30 +386,40 @@ class CollectiveEpilogue< template static bool can_implement( - ProblemShape const& problem_shape, + ProblemShape problem_shape, [[maybe_unused]] Arguments const& args) { - auto problem_shape_MNKL = append<4>(problem_shape.get_host_problem_shape(), 1); - auto [M,N,K,L] = problem_shape_MNKL; bool implementable = true; - if constexpr (is_destination_supported) { - constexpr int tma_alignment_bits_D = cutlass::detail::get_output_alignment_bits(); - constexpr int min_tma_aligned_elements_D = tma_alignment_bits_D / cutlass::sizeof_bits::value; - implementable = cutlass::detail::check_alignment(cute::make_shape(M,N,L), InternalStrideD{}); - } + bool fusion_implementable = true; + + if (problem_shape.is_host_problem_shape_available()) { + for (int i = 0; i < problem_shape.groups(); ++i) { + auto problem_shape_MNKL = append<4>(problem_shape.get_host_problem_shape(i), 1); + auto [M,N,K,L] = problem_shape_MNKL; - if constexpr (not cute::is_void_v) { - constexpr int tma_alignment_bits_C = cutlass::detail::get_input_alignment_bits(); - constexpr int min_tma_aligned_elements_C = tma_alignment_bits_C / cutlass::sizeof_bits::value; - implementable = implementable && cutlass::detail::check_alignment(cute::make_shape(M,N,L), InternalStrideC{}); + if constexpr (is_destination_supported) { + constexpr int tma_alignment_bits_D = cutlass::detail::get_output_alignment_bits(); + constexpr int min_tma_aligned_elements_D = tma_alignment_bits_D / cutlass::sizeof_bits::value; + implementable = implementable && cutlass::detail::check_alignment(cute::make_shape(M,N,L), InternalStrideD{}); + } + + if constexpr (not cute::is_void_v) { + constexpr int tma_alignment_bits_C = cutlass::detail::get_input_alignment_bits(); + constexpr int min_tma_aligned_elements_C = tma_alignment_bits_C / cutlass::sizeof_bits::value; + implementable = implementable && cutlass::detail::check_alignment(cute::make_shape(M,N,L), InternalStrideC{}); + } + + fusion_implementable = fusion_implementable && FusionCallbacks::can_implement(problem_shape_MNKL, args.thread); + } + } + else { + CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Ignoring check to can implement because host problem shape is not available.\n"); } if (!implementable) { CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Problem Size doesn't meet the minimum alignment requirements for TMA.\n"); } - bool fusion_implementable = FusionCallbacks::can_implement(problem_shape, args.thread); - if (!fusion_implementable) { CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Problem Size doesn't meet the minimum requirements for FusionCallbacks.\n"); } @@ -414,10 +468,14 @@ class CollectiveEpilogue< } CUTLASS_DEVICE auto - load_init(Params const& params, int32_t const sm_count, int32_t const sm_idx) const { + load_init( + Params const& params, + TensorMapStorage& shared_tensormaps, + int32_t sm_count, + int32_t sm_idx) { // Initialize tma for loading constexpr bool IsLoad = true; - auto load_tensormaps = tensormaps_init(params, sm_count, sm_idx); + auto load_tensormaps = tensormaps_init(params, shared_tensormaps, 
sm_count, sm_idx, 0); return load_tensormaps; } @@ -426,7 +484,8 @@ class CollectiveEpilogue< class TileShapeMNK, class TileCoordMNKL, class TiledMma, - class TensorMapC + class TensorMapC, + __CUTE_REQUIRES(std::is_pointer_v) > CUTLASS_DEVICE auto load( @@ -440,7 +499,7 @@ class CollectiveEpilogue< TensorStorage& shared_tensors, TensorMapC const& load_tensormap, int subtile_idx=-1, - bool return_prior_state = false) { + bool wait_until_load_finishes = false) { using namespace cute; // Indexing variables @@ -478,17 +537,21 @@ class CollectiveEpilogue< auto pld_callbacks = fusion_callbacks.get_producer_load_callbacks(pld_args); bool is_C_load_needed = is_source_supported && fusion_callbacks.is_C_load_needed(); + LoadPipelineState last_load_producer_state = load_pipe_producer_state; + // Predication for TMA load (one thread issues TMA load) bool issue_tma_load = cute::elect_one_sync(); // Acquire the lock for the first stage - uint64_t* tma_barrier = load_pipeline.producer_get_barrier(load_pipe_producer_state); load_pipeline.producer_acquire(load_pipe_producer_state); + uint64_t* tma_barrier = load_pipeline.producer_get_barrier(load_pipe_producer_state); // Pre-loop fusion callback entry point pld_callbacks.begin(tma_barrier, load_pipe_producer_state.count(), issue_tma_load); - auto prior_state = load_pipe_producer_state; + LoadPipelineState prior_state = load_pipe_producer_state; + + bool did_load = false; CUTLASS_PRAGMA_UNROLL for (int epi_n = 0; epi_n < size<3>(gC_epi); ++epi_n) { @@ -506,15 +569,18 @@ class CollectiveEpilogue< pld_callbacks.step(tma_barrier, epi_m, epi_n, load_pipe_producer_state.count(), issue_tma_load); // Execute the TMA load for C if needed - if (issue_tma_load && is_C_load_needed) { - copy(params.tma_load_c.with(load_tensormap, *tma_barrier, mcast_mask), - bGS_gC(_,_,_,epi_m,epi_n), bGS_sC(_,_,_,load_pipe_producer_state.index())); - load_pipeline.producer_expect_transaction(load_pipe_producer_state); + if (is_C_load_needed) { + if (issue_tma_load) { + copy(params.tma_load_c.with(load_tensormap, *tma_barrier, mcast_mask), + bGS_gC(_,_,_,epi_m,epi_n), bGS_sC(_,_,_,load_pipe_producer_state.index())); + load_pipeline.producer_expect_transaction(load_pipe_producer_state); + } + last_load_producer_state = load_pipe_producer_state; + did_load = true; } // Commit TMA loads for this stage and release the lock load_pipeline.producer_commit(load_pipe_producer_state); - prior_state = load_pipe_producer_state; ++load_pipe_producer_state; } } @@ -522,17 +588,24 @@ class CollectiveEpilogue< // Post-loop fusion callback entry point pld_callbacks.end(); - if (not return_prior_state) { - return load_pipe_producer_state; - } else { - return prior_state; + if (wait_until_load_finishes && did_load) { + typename CollectiveEpilogue::LoadPipelineState epi_load_pipe_tma_consumer_state = + {last_load_producer_state.index(), !last_load_producer_state.phase(), last_load_producer_state.count()}; + load_pipeline.consumer_wait(epi_load_pipe_tma_consumer_state); } + + return load_pipe_producer_state; } CUTLASS_DEVICE auto load_tail( LoadPipeline load_pipeline, LoadPipelineState load_pipe_producer_state) { + + if (!fusion_callbacks.is_producer_load_needed()) { + return load_pipe_producer_state; + } + bool issue_tma_load = cute::elect_one_sync(); if (issue_tma_load) { load_pipeline.producer_tail(load_pipe_producer_state); @@ -564,6 +637,7 @@ class CollectiveEpilogue< TensorStorage& shared_tensors, TensorMapD const& store_tensormap, int subtile_idx=-1) { + using namespace cute; using 
ElementAccumulator = typename AccEngine::value_type; using ElementCompute_ = typename epilogue::fusion::FusionCallbacksTraits::ElementCompute; @@ -869,11 +943,22 @@ class CollectiveEpilogue< } CUTLASS_DEVICE auto - store_init(Params const& params, int32_t const sm_count, int32_t const sm_idx) const { - // Initialize tma - constexpr bool IsLoad = false; - auto store_tensormaps = tensormaps_init(params, sm_count, sm_idx); - return store_tensormaps; + store_init( + Params const& params, + TensorMapStorage& shared_tensormaps, + int32_t sm_count, + int32_t sm_idx, + int32_t warp_group_idx) { + int warp_idx_in_warp_group = canonical_warp_idx_sync() % NumWarpsPerWarpGroup; + // Since only one warp issues TMA store, we only need that one warp to initialize tensormaps + if (warp_idx_in_warp_group == 0) { + // Initialize tma + constexpr bool IsLoad = false; + auto store_tensormaps = tensormaps_init(params, shared_tensormaps, sm_count, sm_idx, warp_group_idx); + return store_tensormaps; + } + TmaDescriptor* null_tma_desc = nullptr; + return cute::make_tuple(null_tma_desc); } // @@ -882,89 +967,145 @@ class CollectiveEpilogue< template CUTLASS_DEVICE auto - tensormaps_init(Params const& params, int32_t const sm_count, int32_t const sm_idx) const { - cute::TmaDescriptor* tma_desc = nullptr; - cute::TmaDescriptor* gmem_tensormap = params.tensormaps; + tensormaps_init( + Params const& params, + TensorMapStorage& shared_tensormaps, + int32_t sm_count, + int32_t sm_idx, + int32_t warp_group_idx) { + + constexpr uint32_t NumInputTensors = NumEpilogueWarpGroups + (cute::is_void_v ? 0 : 1); + Layout desc_layout = make_layout(make_shape(sm_count, Int{})); + + Tensor gmem_tensormap = make_tensor(params.tensormaps, desc_layout); // (SMs, NumInputTensors) + if constexpr (IsLoad) { if (not cute::is_void_v) { - tma_desc = &gmem_tensormap[sm_idx]; + constexpr int C_tensormap_index = NumEpilogueWarpGroups; + Tensor pC_tensormap = make_tensor(params.tma_load_c.get_tma_descriptor(), Int<1>{}, Int<1>{}); + Tensor sC_tensormap = make_tensor(make_smem_ptr(&shared_tensormaps.smem_tensormap_C), Int<1>{}, Int<1>{}); + if (cute::elect_one_sync()) { - // Bringing tensormaps from params to gmem for modification later - Tensor pC_tensormap = make_tensor(params.tma_load_c.get_tma_descriptor(), Int<1>{}, Int<1>{}); - Tensor gC_tensormap = make_tensor(tma_desc, Int<1>{}, Int<1>{}); - copy(recast(pC_tensormap), recast(gC_tensormap)); + // Bringing tensormaps from params to smem for modification later + copy(recast(pC_tensormap), recast(sC_tensormap)); } + __syncwarp(); + return cute::make_tuple(&gmem_tensormap(sm_idx, C_tensormap_index)); } - } else { - int const offset_Ddesc = cute::is_void_v ? 
0 : sm_count; - tma_desc = &gmem_tensormap[sm_idx + offset_Ddesc]; + TmaDescriptor* null_tma_desc = nullptr; + return cute::make_tuple(null_tma_desc); + } + else { + Tensor pD_tensormap = make_tensor(params.tma_store_d.get_tma_descriptor(), Int<1>{}, Int<1>{}); + Tensor sD_tensormap = make_tensor(make_smem_ptr(&shared_tensormaps.smem_tensormap_D[warp_group_idx]), Int<1>{}, Int<1>{}); + if (cute::elect_one_sync()) { - // Bringing tensormaps from params to gmem for modification later - Tensor pD_tensormap = make_tensor(params.tma_store_d.get_tma_descriptor(), Int<1>{}, Int<1>{}); - Tensor gD_tensormap = make_tensor(tma_desc, Int<1>{}, Int<1>{}); - copy(recast(pD_tensormap), recast(gD_tensormap)); + // Bringing tensormaps from params to smem for modification later + copy(recast(pD_tensormap), recast(sD_tensormap)); } + __syncwarp(); + return cute::make_tuple(&gmem_tensormap(sm_idx, warp_group_idx)); } - - return cute::make_tuple(tma_desc); } - // Bringing tensormaps to smem (to be done by single thread) + // Replace address for the global tensor (to be done by single thread) template CUTLASS_DEVICE void - tensormaps_fetch_to_smem( - TensorMapStorage& shared_tensormap, - cute::TmaDescriptor const* tensormap) const { + tensormaps_replace_global_address( + TensorMapStorage& shared_tensormaps, + Params const& params, + int32_t next_batch, + int32_t warp_group_idx) { + // Replacing global_address for the next batch if constexpr (IsLoad) { - if (not cute::is_void_v) { - Tensor gC_tensormap = make_tensor(make_gmem_ptr(tensormap), Int<1>{}, Int<1>{}); - Tensor sC_tensormap = make_tensor(make_smem_ptr(&shared_tensormap.smem_tensormap_C), Int<1>{}, Int<1>{}); - copy(recast(gC_tensormap), recast(sC_tensormap)); + if constexpr (is_source_supported) { + cute::tma_descriptor_replace_addr_in_shared_mem(shared_tensormaps.smem_tensormap_C, + params.ptr_C[next_batch]); } - } else { - Tensor gD_tensormap = make_tensor(make_gmem_ptr(tensormap), Int<1>{}, Int<1>{}); - Tensor sD_tensormap = make_tensor(make_smem_ptr(&shared_tensormap.smem_tensormap_D), Int<1>{}, Int<1>{}); - copy(recast(gD_tensormap), recast(sD_tensormap)); } - cp_async_fence(); - cp_async_wait<0>(); + else if constexpr (is_destination_supported) { + cute::tma_descriptor_replace_addr_in_shared_mem(shared_tensormaps.smem_tensormap_D[warp_group_idx], + params.ptr_D[next_batch]); + } } - // Replace address for the global tensor (to be done by single thread) - template + // Replace dim and strides for the global tensor - used only for Grouped GEMM (to be done by single thread) + template CUTLASS_DEVICE void - tensormaps_replace_global_address( - TensorMapStorage& shared_tensormap, + tensormaps_replace_global_tensor_properties( + TensorMapStorage& shared_tensormaps, Params const& params, - int32_t next_batch) { - // Replacing global_address for the next batch + int32_t next_group, + ProblemShape_MNKL problem_shape_mnkl, + int32_t warp_group_idx) { + const uint32_t M = get<0>(problem_shape_mnkl); + const uint32_t N = get<1>(problem_shape_mnkl); + // Only consider dimensions and strides that we need to recalculate and replace for each group + constexpr int TensorRank = rank(ProblemShape_MNKL{}) - 1; // excluding either M or N + static_assert(TensorRank == Int<3>{}, + "Descriptor modification for global dims & strides expects rank as 3."); + + cute::array prob_shape = {1,1,1}; + cute::array prob_stride = {0,0,0}; + if constexpr (IsLoad) { - if (not cute::is_void_v) { - cute::tma_descriptor_replace_addr_in_shared_mem(shared_tensormap.smem_tensormap_C, - 
params.ptr_C[next_batch]); + if constexpr (is_source_supported) { + ElementC const* ptr_C = nullptr; + Tensor tensor_c = make_tensor(ptr_C, make_layout(make_shape(M,N,Int<1>{}), params.dC[next_group])); + + cute::detail::fill_tma_gmem_shape_stride(params.tma_load_c, tensor_c, + prob_shape, prob_stride); + // Convert strides to byte strides + for (uint64_t& stride : prob_stride) { + stride = (stride * sizeof_bits_v) / 8; + } + cute::tma_descriptor_replace_dims_strides_in_shared_mem(shared_tensormaps.smem_tensormap_C, + prob_shape, + prob_stride); } - } else { - cute::tma_descriptor_replace_addr_in_shared_mem(shared_tensormap.smem_tensormap_D, - params.ptr_D[next_batch]); + } + else if constexpr (is_destination_supported) { + ElementD const* ptr_D = nullptr; + + // tma_store_c should be a gmem_tensor, second argument should be a stride + + Tensor tensor_d = make_tensor(ptr_D, make_layout(make_shape(M,N,Int<1>{}), params.dD[next_group])); + + cute::detail::fill_tma_gmem_shape_stride(params.tma_store_d, tensor_d, + prob_shape, prob_stride); + // Convert strides to byte strides + for (uint64_t& stride : prob_stride) { + stride = (stride * sizeof_bits_v) / 8; + } + + cute::tma_descriptor_replace_dims_strides_in_shared_mem(shared_tensormaps.smem_tensormap_D[warp_group_idx], + prob_shape, + prob_stride); } } - template + template CUTLASS_DEVICE void tensormaps_perform_update( - TensorMapStorage& shared_tensormap, + TensorMapStorage& shared_tensormaps, Params const& params, cute::TmaDescriptor const* tensormap, - int32_t next_batch) { + ProblemShape_MNKL problem_shape_mnkl, + int32_t next_batch, + int32_t warp_group_idx) { if (cute::elect_one_sync()) { - // Bringing tensormaps to smem - tensormaps_fetch_to_smem(shared_tensormap, tensormap); // Replacing global_address for the next batch - tensormaps_replace_global_address(shared_tensormap, params, next_batch); + tensormaps_replace_global_address(shared_tensormaps, params, next_batch, warp_group_idx); + + if constexpr (IsGroupedGemmKernel) { + // Replacing global dims and strides for the next batch + tensormaps_replace_global_tensor_properties( + shared_tensormaps, params, next_batch, problem_shape_mnkl, warp_group_idx); + } } } @@ -972,16 +1113,18 @@ class CollectiveEpilogue< CUTLASS_DEVICE void tensormaps_cp_fence_release( - TensorMapStorage& shared_tensormap, + TensorMapStorage& shared_tensormaps, cute::TmaDescriptor const* tensormap, - [[maybe_unused]] uint32_t lane_predicate) { + [[maybe_unused]] uint32_t lane_predicate, + int32_t warp_group_idx = 0) { // Entire warp must do this (ie its aligned) if constexpr (IsLoad) { - if (not cute::is_void_v) { - tma_descriptor_cp_fence_release(tensormap, shared_tensormap.smem_tensormap_C); + if constexpr (is_source_supported) { + tma_descriptor_cp_fence_release(tensormap, shared_tensormaps.smem_tensormap_C); } - } else { - tma_descriptor_cp_fence_release(tensormap, shared_tensormap.smem_tensormap_D); + } + else if constexpr (is_destination_supported) { + tma_descriptor_cp_fence_release(tensormap, shared_tensormaps.smem_tensormap_D[warp_group_idx]); } } @@ -990,10 +1133,11 @@ class CollectiveEpilogue< void tensormaps_fence_acquire(cute::TmaDescriptor const* tensormap) { if constexpr (IsLoad) { - if (not cute::is_void_v) { + if constexpr (not cute::is_void_v) { cute::tma_descriptor_fence_acquire(tensormap); } - } else { + } + else { cute::tma_descriptor_fence_acquire(tensormap); } } diff --git a/include/cutlass/epilogue/dispatch_policy.hpp b/include/cutlass/epilogue/dispatch_policy.hpp index 
9f9576b417..e96f413445 100644 --- a/include/cutlass/epilogue/dispatch_policy.hpp +++ b/include/cutlass/epilogue/dispatch_policy.hpp @@ -51,7 +51,21 @@ struct PtrArrayNoSmemWarpSpecialized {}; struct PtrArrayPlanarComplexNoSmemWarpSpecialized {}; struct TmaWarpSpecialized {}; struct TmaWarpSpecializedCooperative {}; -struct PtrArrayTmaWarpSpecializedCooperative {}; + +struct PtrArrayTmaWarpSpecializedCooperative { + static constexpr int NumEpilogueWarpGroups = 2; +}; + +// Standard warp specialized epilogue +struct PtrArrayTmaWarpSpecialized { + static constexpr int NumEpilogueWarpGroups = 1; +}; + +// Pingpong kernel epilogue +struct PtrArrayTmaWarpSpecializedPingpong { + static constexpr int NumEpilogueWarpGroups = 2; +}; + // DEPRECATED schedules, will be removed in next release struct TmaWarpSpecializedElementwiseBase : public TmaWarpSpecialized {}; struct TmaWarpSpecializedCooperativeElementwiseBase : public TmaWarpSpecializedCooperative {}; @@ -151,7 +165,8 @@ template< int StagesD_, int FragmentSize_, bool ReuseSmemC_, - bool DelayTmaStore_ + bool DelayTmaStore_, + int NumEpilogueWarpGroups_ > struct Sm90PtrArrayTmaWarpSpecialized { constexpr static int StagesC = StagesC_; @@ -159,6 +174,7 @@ struct Sm90PtrArrayTmaWarpSpecialized { constexpr static int FragmentSize = FragmentSize_; constexpr static bool ReuseSmemC = ReuseSmemC_; constexpr static bool DelayTmaStore = DelayTmaStore_; + constexpr static int NumEpilogueWarpGroups = NumEpilogueWarpGroups_; }; // DEPRECATED policies, will be removed in next release diff --git a/include/cutlass/epilogue/fusion/operations.hpp b/include/cutlass/epilogue/fusion/operations.hpp index a01288778c..0bfacf34cc 100644 --- a/include/cutlass/epilogue/fusion/operations.hpp +++ b/include/cutlass/epilogue/fusion/operations.hpp @@ -32,6 +32,7 @@ #pragma once #include +#include ///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/epilogue/fusion/sm90_callbacks_tma_warpspecialized.hpp b/include/cutlass/epilogue/fusion/sm90_callbacks_tma_warpspecialized.hpp index 3f43f60d6e..ece5ac542e 100644 --- a/include/cutlass/epilogue/fusion/sm90_callbacks_tma_warpspecialized.hpp +++ b/include/cutlass/epilogue/fusion/sm90_callbacks_tma_warpspecialized.hpp @@ -179,6 +179,89 @@ struct FusionCallbacks< ///////////////////////////////////////////////////////////////////////////////////////////////// +// D = alpha * acc + beta * C, where beta and alpha can be vectors for each batch +template< + class ElementOutput, + class ElementCompute, + class ElementSource = ElementOutput, + class ElementScalar = ElementCompute, + FloatRoundStyle RoundStyle = FloatRoundStyle::round_to_nearest +> +using Sm90LinearCombinationPtrArray = + Sm90EVT, // beta * C + (alpha * acc) + Sm90ScalarBroadcastPtrArray>, // beta + Sm90SrcFetch, // C + Sm90EVT, // alpha * acc + Sm90ScalarBroadcastPtrArray>, // alpha + Sm90AccFetch // acc + > + >; + +template < + int StagesC, + int StagesD, + int FragmentSize, + bool ReuseSmemC, + bool DelayTmaStore, + int NumEpilogueWarpGroups, + class ElementOutput, + class ElementCompute, + class ElementSource, + class ElementScalar, + FloatRoundStyle RoundStyle, + class CtaTileShapeMNK, + class EpilogueTile +> +struct FusionCallbacks< + epilogue::Sm90PtrArrayTmaWarpSpecialized, + fusion::LinearCombination, + CtaTileShapeMNK, + EpilogueTile +> : Sm90LinearCombinationPtrArray::type, ElementCompute, ElementSource, ElementScalar, RoundStyle> { + + using Impl = Sm90LinearCombinationPtrArray::type, 
ElementCompute, ElementSource, ElementScalar, RoundStyle>; + using Operation = fusion::LinearCombination; + + struct Arguments { + ElementScalar alpha = ElementScalar(1); + ElementScalar beta = ElementScalar(0); + ElementScalar const* alpha_ptr = nullptr; + ElementScalar const* beta_ptr = nullptr; + ElementScalar const* const* alpha_ptr_array = nullptr; + ElementScalar const* const* beta_ptr_array = nullptr; + + using StrideAlpha = Stride<_0,_0,int>; + using StrideBeta = Stride<_0,_0,int>; + StrideAlpha dAlpha = {_0{}, _0{}, 0}; + StrideBeta dBeta = {_0{}, _0{}, 0}; + + operator typename Impl::Arguments() const { + return + { // ternary op : beta * C + (alpha * acc) + {{beta}, {beta_ptr}, {beta_ptr_array}, {dBeta}}, // leaf args : beta + {}, // leaf args : C + { // binary op : alpha * acc + {{alpha}, {alpha_ptr}, {alpha_ptr_array}, {dAlpha}}, // leaf args : alpha + {}, // leaf args : acc + {} // binary args : multiplies + }, // end binary op + {} // ternary args : multiply_add + }; // end ternary op + } + }; + + // Ctor inheritance + using Impl::Impl; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + // D = activation(alpha * acc + beta * C) template< template class ActivationFn, diff --git a/include/cutlass/epilogue/fusion/sm90_visitor_load_tma_warpspecialized.hpp b/include/cutlass/epilogue/fusion/sm90_visitor_load_tma_warpspecialized.hpp index 4eb326b3dd..aedacb552e 100644 --- a/include/cutlass/epilogue/fusion/sm90_visitor_load_tma_warpspecialized.hpp +++ b/include/cutlass/epilogue/fusion/sm90_visitor_load_tma_warpspecialized.hpp @@ -37,6 +37,7 @@ #include "cutlass/cutlass.h" #include "cutlass/arch/barrier.h" +#include "cutlass/epilogue/collective/detail.hpp" #include "cute/tensor.hpp" #include "sm90_visitor_tma_warpspecialized.hpp" @@ -514,7 +515,8 @@ struct Sm90ScalarBroadcast { if (params_ptr->scalar_ptrs[0] != nullptr) { scalar = params_ptr->scalar_ptrs[0][l_offset]; - } else { + } + else { // batch stride is ignored for nullptr fallback scalar = params_ptr->scalars[0]; } @@ -541,6 +543,169 @@ struct Sm90ScalarBroadcast { } }; +// Scalar broadcast +// Supports reduction over multiple broadcasts to support fusions such as fp8 scaling factors +template< + class Element, + class StrideMNL = Stride<_0,_0,_0>, + int BroadcastCount = 1, + template class ReductionFn = multiplies +> +struct Sm90ScalarBroadcastPtrArray { + static_assert(is_static_v(StrideMNL{}))>); // batch stride can be dynamic or static + static_assert(take<0,2>(StrideMNL{}) == Stride<_0,_0>{}); + + struct SharedStorage { }; + + struct Arguments { + Element scalars[BroadcastCount] = {}; + Element const* scalar_ptrs[BroadcastCount] = {}; + Element const* const* scalar_ptr_arrays[BroadcastCount] = {}; + StrideMNL dScalar[BroadcastCount] = {}; + }; + + using Params = Arguments; + + template + static constexpr Params + to_underlying_arguments(ProblemShape const& problem_shape, Arguments const& args, void* workspace) { + return args; + } + + template + static bool + can_implement(ProblemShape const& problem_shape, Arguments const& args) { + return true; + } + + template + static size_t + get_workspace_size(ProblemShape const& problem_shape, Arguments const& args) { + return 0; + } + + template + static cutlass::Status + initialize_workspace(ProblemShape const& problem_shape, Arguments const& args, void* workspace, cudaStream_t stream, + CudaHostAdapter *cuda_adapter = nullptr) { + return cutlass::Status::kSuccess; + } + + CUTLASS_DEVICE bool + is_producer_load_needed() 
const { + // producer load is needed if Element is not void and we have multiple scalars + return !cute::is_void_v and size<2>(params_ptr->dScalar[0]) != 0; + } + + CUTLASS_DEVICE bool + is_C_load_needed() const { + return false; + } + + // This must be called after update_scalar is called + CUTLASS_DEVICE bool + is_zero() const { + return scalar == Element(0); + } + + CUTLASS_HOST_DEVICE + Sm90ScalarBroadcastPtrArray() { } + + CUTLASS_HOST_DEVICE + Sm90ScalarBroadcastPtrArray(Params const& params, SharedStorage const& shared_storage) + : params_ptr(¶ms) { + // Get the scalar for non-batched broadcast + if (size<2>(params_ptr->dScalar[0]) == 0) { + update_scalar(); + } + } + + Element scalar; + Params const* params_ptr; + + template + CUTLASS_DEVICE auto + get_producer_load_callbacks(ProducerLoadArgs const& args) { + // Get the scalar for batched broadcast + if (get<2>(params_ptr->dScalar[0]) != 0) { + auto [m_coord, n_coord, k_coord, l_coord] = args.tile_coord_mnkl; + update_scalar(l_coord); + } + + return EmptyProducerLoadCallbacks{}; + } + + struct ConsumerStoreCallbacks : EmptyConsumerStoreCallbacks { + CUTLASS_DEVICE + ConsumerStoreCallbacks(Element scalar) + : scalar(scalar) {} + + Element scalar; + + template + CUTLASS_DEVICE Array + visit(Array const& frg_acc, int epi_v, int epi_m, int epi_n) { + Array frg_scalar; + frg_scalar.fill(scalar); + + return frg_scalar; + } + + }; + + template < + bool ReferenceSrc, // do register tensors reference the src or dst layout of the tiled copy + class... Args + > + CUTLASS_DEVICE auto + get_consumer_store_callbacks(ConsumerStoreArgs const& args) { + + // Get the scalar for batched broadcast + if (get<2>(params_ptr->dScalar[0]) != 0) { + auto [m_coord, n_coord, k_coord, l_coord] = args.tile_coord_mnkl; + update_scalar(l_coord); + } + + return ConsumerStoreCallbacks(scalar); + } + +private: + CUTLASS_DEVICE void + update_scalar(int l_coord = 0) { + int l_offset = l_coord * size<2>(params_ptr->dScalar[0]); + + if (params_ptr->scalar_ptr_arrays[0] != nullptr) { + scalar = *(params_ptr->scalar_ptr_arrays[0][l_offset]); + } + else if (params_ptr->scalar_ptrs[0] != nullptr) { + scalar = params_ptr->scalar_ptrs[0][l_offset]; + } + else { + // batch stride is ignored for nullptr fallback + scalar = params_ptr->scalars[0]; + } + + // Do reduction over multiple broadcasts if necessary + ReductionFn reduction_fn; + CUTLASS_PRAGMA_UNROLL + for (int i = 1; i < BroadcastCount; ++i) { + + if (params_ptr->scalar_ptr_arrays[i] != nullptr) { + int rest_l_offset = l_coord * size<2>(params_ptr->dScalar[i]); + scalar = reduction_fn(scalar, *(params_ptr->scalar_ptr_arrays[i][rest_l_offset])); + } + if (params_ptr->scalar_ptrs[i] != nullptr) { + int rest_l_offset = l_coord * size<2>(params_ptr->dScalar[i]); + scalar = reduction_fn(scalar, params_ptr->scalar_ptrs[i][rest_l_offset]); + } + else { + // batch stride is ignored for nullptr fallback + scalar = reduction_fn(scalar, params_ptr->scalars[i]); + } + } + } +}; + ///////////////////////////////////////////////////////////////////////////////////////////////// namespace detail { diff --git a/include/cutlass/gemm/collective/builders/sm90_gmma_builder.inl b/include/cutlass/gemm/collective/builders/sm90_gmma_builder.inl index 25b1f84832..0b3ecb15c6 100644 --- a/include/cutlass/gemm/collective/builders/sm90_gmma_builder.inl +++ b/include/cutlass/gemm/collective/builders/sm90_gmma_builder.inl @@ -31,6 +31,10 @@ #pragma once #include "cutlass/gemm/collective/builders/sm90_common.inl" +#include 
"cutlass/gemm/dispatch_policy.hpp" +#include "cutlass/pipeline/sm90_pipeline.hpp" +#include "cutlass/gemm/collective/collective_mma_decl.hpp" +#include "cutlass/gemm/collective/collective_builder_decl.hpp" // SM90 Collective Builders should be used only starting CUDA 12.0 #if (__CUDACC_VER_MAJOR__ >= 12) @@ -177,10 +181,12 @@ struct CollectiveBuilder< StageCountType, KernelScheduleType, cute::enable_if_t< - (cute::is_same_v || - cute::is_same_v || - cute::is_same_v || - cute::is_same_v) && + (cute::is_any_of_v) && not detail::is_use_rmem_A()> > { static_assert(is_static::value); @@ -191,10 +197,12 @@ struct CollectiveBuilder< static_assert(detail::is_aligned(), "Should meet TMA alignment requirement\n"); - static constexpr bool IsArrayOfPointersGemm = (cute::is_same_v); + static constexpr bool IsArrayOfPointersGemm = (cute::is_any_of_v); static constexpr bool IsFP8Input = detail::is_input_fp8(); static_assert(!IsFP8Input || (IsFP8Input && !IsArrayOfPointersGemm), - "Kernel[Array/Group]TmaWarpSpecializedCooperative is only compatible with FP8 FastAccum version right now\n"); + "KernelPtrArrayTmaWarpSpecialized[Cooperative|Pingpong] is only compatible with FP8 FastAccum version right now."); // For fp32 types, map to tf32 MMA value type using ElementAMma = cute::conditional_t, tfloat32_t, ElementA>; @@ -203,8 +211,10 @@ struct CollectiveBuilder< static constexpr cute::GMMA::Major GmmaMajorA = detail::gmma_ss_tag_to_major_A(); static constexpr cute::GMMA::Major GmmaMajorB = detail::gmma_ss_tag_to_major_B(); - using AtomLayoutMNK = cute::conditional_t< - cute::is_same_v || IsArrayOfPointersGemm, + static constexpr bool IsCooperative = cute::is_any_of_v; + using AtomLayoutMNK = cute::conditional_t>, Layout>>; using TiledMma = decltype(cute::make_tiled_mma(cute::GMMA::ss_op_selector< @@ -218,7 +228,10 @@ struct CollectiveBuilder< using SmemLayoutAtomB = decltype(detail::ss_smem_selector< GmmaMajorB, ElementBMma, decltype(cute::get<1>(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>()); - static constexpr int PipelineStages = detail::compute_stage_count_or_override(TensorMapStorage); + + static constexpr int PipelineStages = detail::compute_stage_count_or_override(StageCountType{}); using DispatchPolicy = cute::conditional_t, @@ -505,10 +518,12 @@ struct CollectiveBuilder< StageCountType, KernelScheduleType, cute::enable_if_t< - cute::is_same_v || - cute::is_same_v || - cute::is_same_v || - cute::is_same_v> + cute::is_any_of_v> > { static_assert(is_static::value); static_assert(is_static::value); @@ -526,10 +541,15 @@ struct CollectiveBuilder< static constexpr cute::GMMA::Major GmmaMajorA = detail::gmma_ss_tag_to_major_A(); static constexpr cute::GMMA::Major GmmaMajorB = detail::gmma_ss_tag_to_major_B(); - static constexpr bool IsArrayOfPointersGemm = (cute::is_same_v); - using AtomLayoutMNK = cute::conditional_t || - IsArrayOfPointersGemm, - Layout>, Layout>>; + static constexpr bool IsArrayOfPointersGemm = cute::is_any_of_v; + + static constexpr bool IsCooperative = cute::is_any_of_v; + + using AtomLayoutMNK = cute::conditional_t>, Layout>>; using TiledMma = decltype(cute::make_tiled_mma(cute::GMMA::ss_op_selector< ElementA, ElementB, ElementAccumulator, TileShape_MNK, GmmaMajorA, GmmaMajorB>(), AtomLayoutMNK{})); @@ -542,7 +562,11 @@ struct CollectiveBuilder< using SmemLayoutAtomB = decltype(detail::ss_smem_selector< GmmaMajorB, ElementB, decltype(cute::get<1>(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>()); - static constexpr int PipelineStages = 
detail::compute_stage_count_or_override(TensorMapStorage); + static constexpr int Sm90ReducedSmemCapacityBytes = detail::sm90_smem_capacity_bytes - KernelSmemCarveout; + + static constexpr int PipelineStages = detail::compute_stage_count_or_override(StageCountType{}); using DispatchPolicy = cute::conditional_t, diff --git a/include/cutlass/gemm/collective/sm90_mma_array_tma_gmma_ss_warpspecialized.hpp b/include/cutlass/gemm/collective/sm90_mma_array_tma_gmma_ss_warpspecialized.hpp index 4f2837d17e..75d7bb39e9 100644 --- a/include/cutlass/gemm/collective/sm90_mma_array_tma_gmma_ss_warpspecialized.hpp +++ b/include/cutlass/gemm/collective/sm90_mma_array_tma_gmma_ss_warpspecialized.hpp @@ -623,56 +623,42 @@ struct CollectiveMma< // CUTLASS_DEVICE auto - tensormaps_init(Params const& mainloop_params, int32_t const sm_count, int32_t const sm_idx) const { + tensormaps_init( + Params const& mainloop_params, + TensorMapStorage& shared_tensormaps, + int32_t sm_count, + int32_t sm_idx) { cute::TmaDescriptor* gmem_tensormap = reinterpret_cast(mainloop_params.tensormaps); cute::TmaDescriptor* tma_desc_a = &gmem_tensormap[sm_idx]; cute::TmaDescriptor* tma_desc_b = &gmem_tensormap[sm_idx + sm_count]; if (cute::elect_one_sync()) { - // Bringing tensormaps from params to gmem for modification later + // Bringing tensormaps from params to smem for modification later Tensor pA_tensormap = make_tensor(mainloop_params.tma_load_a.get_tma_descriptor(), Int<1>{}, Int<1>{}); - Tensor gA_tensormap = make_tensor(tma_desc_a, Int<1>{}, Int<1>{}); + Tensor sA_tensormap = make_tensor(make_smem_ptr(&shared_tensormaps.smem_tensormap_A), Int<1>{}, Int<1>{}); Tensor pB_tensormap = make_tensor(mainloop_params.tma_load_b.get_tma_descriptor(), Int<1>{}, Int<1>{}); - Tensor gB_tensormap = make_tensor(tma_desc_b, Int<1>{}, Int<1>{}); + Tensor sB_tensormap = make_tensor(make_smem_ptr(&shared_tensormaps.smem_tensormap_B), Int<1>{}, Int<1>{}); - copy(recast(pA_tensormap), recast(gA_tensormap)); - copy(recast(pB_tensormap), recast(gB_tensormap)); + copy(recast(pA_tensormap), recast(sA_tensormap)); + copy(recast(pB_tensormap), recast(sB_tensormap)); } + __syncwarp(); return cute::make_tuple(tma_desc_a, tma_desc_b); } - // Bringing tensormaps to smem (to be done by single thread) - template - CUTLASS_DEVICE - void - tensormaps_fetch_to_smem( - TensorMapStorage& shared_tensormap, - cute::tuple const& input_tensormaps) const { - Tensor gA_tensormap = make_tensor(make_gmem_ptr(get<0>(input_tensormaps)), Int<1>{}, Int<1>{}); - Tensor sA_tensormap = make_tensor(make_smem_ptr(&shared_tensormap.smem_tensormap_A), Int<1>{}, Int<1>{}); - Tensor gB_tensormap = make_tensor(make_gmem_ptr(get<1>(input_tensormaps)), Int<1>{}, Int<1>{}); - Tensor sB_tensormap = make_tensor(make_smem_ptr(&shared_tensormap.smem_tensormap_B), Int<1>{}, Int<1>{}); - - copy(recast(gA_tensormap), recast(sA_tensormap)); - copy(recast(gB_tensormap), recast(sB_tensormap)); - - cp_async_fence(); - cp_async_wait<0>(); - } - // Replace address for the global tensor (to be done by single thread) CUTLASS_DEVICE void tensormaps_replace_global_address( - TensorMapStorage& shared_tensormap, + TensorMapStorage& shared_tensormaps, Params const& mainloop_params, int32_t next_batch) { // Replacing global_address for the next batch - cute::tma_descriptor_replace_addr_in_shared_mem(shared_tensormap.smem_tensormap_A, + cute::tma_descriptor_replace_addr_in_shared_mem(shared_tensormaps.smem_tensormap_A, mainloop_params.ptr_A[next_batch]); - 
cute::tma_descriptor_replace_addr_in_shared_mem(shared_tensormap.smem_tensormap_B, + cute::tma_descriptor_replace_addr_in_shared_mem(shared_tensormaps.smem_tensormap_B, mainloop_params.ptr_B[next_batch]); } @@ -681,7 +667,7 @@ struct CollectiveMma< CUTLASS_DEVICE void tensormaps_replace_global_tensor_properties( - TensorMapStorage& shared_tensormap, + TensorMapStorage& shared_tensormaps, Params const& mainloop_params, int32_t next_group, ProblemShape_MNKL problem_shape_mnkl) { @@ -716,10 +702,10 @@ struct CollectiveMma< stride = (stride * sizeof_bits_v) / 8; } - cute::tma_descriptor_replace_dims_strides_in_shared_mem(shared_tensormap.smem_tensormap_A, + cute::tma_descriptor_replace_dims_strides_in_shared_mem(shared_tensormaps.smem_tensormap_A, prob_shape_A, prob_stride_A); - cute::tma_descriptor_replace_dims_strides_in_shared_mem(shared_tensormap.smem_tensormap_B, + cute::tma_descriptor_replace_dims_strides_in_shared_mem(shared_tensormaps.smem_tensormap_B, prob_shape_B, prob_stride_B); } @@ -728,21 +714,19 @@ struct CollectiveMma< CUTLASS_DEVICE void tensormaps_perform_update( - TensorMapStorage& shared_tensormap, + TensorMapStorage& shared_tensormaps, Params const& mainloop_params, cute::tuple const& input_tensormaps, ProblemShape_MNKL problem_shape_mnkl, int32_t next_batch) { if (cute::elect_one_sync()) { - // Bringing tensormaps to smem - tensormaps_fetch_to_smem(shared_tensormap, input_tensormaps); // Replacing global_address for the next batch - tensormaps_replace_global_address(shared_tensormap, mainloop_params, next_batch); + tensormaps_replace_global_address(shared_tensormaps, mainloop_params, next_batch); if constexpr (IsGroupedGemmKernel) { // Replacing global dims and strides for the next batch - tensormaps_replace_global_tensor_properties(shared_tensormap, + tensormaps_replace_global_tensor_properties(shared_tensormaps, mainloop_params, next_batch, problem_shape_mnkl); } } @@ -752,11 +736,11 @@ struct CollectiveMma< CUTLASS_DEVICE void tensormaps_cp_fence_release ( - TensorMapStorage& shared_tensormap, + TensorMapStorage& shared_tensormaps, cute::tuple const& input_tensormaps) { // Entire warp must do this (i.e. 
it's aligned) - tma_descriptor_cp_fence_release(get<0>(input_tensormaps), shared_tensormap.smem_tensormap_A); - tma_descriptor_cp_fence_release(get<1>(input_tensormaps), shared_tensormap.smem_tensormap_B); + tma_descriptor_cp_fence_release(get<0>(input_tensormaps), shared_tensormaps.smem_tensormap_A); + tma_descriptor_cp_fence_release(get<1>(input_tensormaps), shared_tensormaps.smem_tensormap_B); } // The entire warp must call this function collectively (that is, the instructions are aligned) diff --git a/include/cutlass/gemm/dispatch_policy.hpp b/include/cutlass/gemm/dispatch_policy.hpp index 2e820b6136..c1c2308b9d 100644 --- a/include/cutlass/gemm/dispatch_policy.hpp +++ b/include/cutlass/gemm/dispatch_policy.hpp @@ -98,6 +98,7 @@ struct KernelTmaWarpSpecialized { }; struct KernelTmaWarpSpecializedPingpong { }; struct KernelTmaWarpSpecializedCooperative { }; struct KernelPtrArrayTmaWarpSpecializedCooperative { }; +struct KernelPtrArrayTmaWarpSpecializedPingpong { }; ////////////////////////////////////////////////////////////////////////////// @@ -111,6 +112,7 @@ struct KernelTmaWarpSpecializedFP8FastAccum : KernelTmaWarpSpecialized { }; struct KernelTmaWarpSpecializedPingpongFP8FastAccum : KernelTmaWarpSpecializedPingpong { }; struct KernelTmaWarpSpecializedCooperativeFP8FastAccum: KernelTmaWarpSpecializedCooperative { }; struct KernelPtrArrayTmaWarpSpecializedCooperativeFP8FastAccum : KernelPtrArrayTmaWarpSpecializedCooperative { }; +struct KernelPtrArrayTmaWarpSpecializedPingpongFP8FastAccum : KernelPtrArrayTmaWarpSpecializedPingpong { }; // Policies to opt into mixed type GEMMs struct KernelTmaWarpSpecializedMixedInput : KernelTmaWarpSpecialized { }; @@ -286,8 +288,9 @@ struct MainloopSm90ArrayTmaGmmaWarpSpecialized { using ArchTag = arch::Sm90; using Schedule = KernelSchedule; static_assert( - cute::is_base_of_v, - "KernelSchedule must be one of the Ptr-Array or Grouped Gemm TMA Warp Specialized Cooperative policies"); + cute::is_base_of_v || + cute::is_base_of_v, + "KernelSchedule must be one of the Ptr-Array or Grouped Gemm TMA Warp Specialized Cooperative or Pingpong policies"); }; ////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/gemm/kernel/gemm_universal.hpp b/include/cutlass/gemm/kernel/gemm_universal.hpp index b682be867d..6c7b89a241 100644 --- a/include/cutlass/gemm/kernel/gemm_universal.hpp +++ b/include/cutlass/gemm/kernel/gemm_universal.hpp @@ -61,5 +61,6 @@ struct IsCutlass3ArrayKernel or "); + + static_assert(cute::is_base_of_v); + // Mainloop derived types using CollectiveMainloop = CollectiveMainloop_; using TileShape = typename CollectiveMainloop::TileShape; @@ -119,8 +124,9 @@ class GemmUniversal< using TileSchedulerParams = typename TileScheduler::Params; static constexpr uint32_t NumLoadWarpGroups = 1; - static constexpr uint32_t NumMmaWarpGroups = CUTE_STATIC_V(size(TiledMma{})) / NumThreadsPerWarpGroup; - static constexpr uint32_t MaxThreadsPerBlock = CUTE_STATIC_V(size(TiledMma{})) + (NumLoadWarpGroups * NumThreadsPerWarpGroup); + static constexpr uint32_t NumMmaThreads = CUTE_STATIC_V(size(TiledMma{})); + static constexpr uint32_t NumMmaWarpGroups = NumMmaThreads / NumThreadsPerWarpGroup; + static constexpr uint32_t MaxThreadsPerBlock = NumMmaThreads + (NumLoadWarpGroups * NumThreadsPerWarpGroup); static constexpr uint32_t MinBlocksPerMultiprocessor = 1; /// Register requirement for Load and Math WGs @@ -215,11 +221,11 @@ class GemmUniversal< workspace_offset = round_nearest(workspace_offset, 
MinWorkspaceAlignment); void* epilogue_workspace = workspace_ptr + workspace_offset; - workspace_offset += CollectiveEpilogue::get_workspace_size(problem_shapes, args.epilogue, args.hw_info.sm_count); + workspace_offset += CollectiveEpilogue::get_workspace_size(problem_shapes, args.epilogue, sm_count); workspace_offset = round_nearest(workspace_offset, MinWorkspaceAlignment); void* mainloop_workspace = workspace_ptr + workspace_offset; - workspace_offset += CollectiveMainloop::get_workspace_size(problem_shapes, args.mainloop, args.hw_info.sm_count); + workspace_offset += CollectiveMainloop::get_workspace_size(problem_shapes, args.mainloop, sm_count); workspace_offset = round_nearest(workspace_offset, MinWorkspaceAlignment); // Precompute the sub tiles numbers in epilogue, pass into tile scheduler. Therefore it will be used @@ -275,9 +281,6 @@ class GemmUniversal< args.scheduler, typename ProblemShape::UnderlyingProblemShape{}, args.hw_info, NumMmaWarpGroups, NumEpilogueSubTiles); workspace_size = round_nearest(workspace_size, MinWorkspaceAlignment); - workspace_size += CollectiveEpilogue::get_workspace_size(args.problem_shape, args.epilogue, args.hw_info.sm_count); - workspace_size = round_nearest(workspace_size, MinWorkspaceAlignment); - // Get SM count if needed, otherwise use user supplied SM count int sm_count = args.hw_info.sm_count; if (sm_count <= 0) { @@ -286,6 +289,9 @@ class GemmUniversal< sm_count = KernelHardwareInfo::query_device_multiprocessor_count(args.hw_info.device_id); } + workspace_size += CollectiveEpilogue::get_workspace_size(args.problem_shape, args.epilogue, sm_count); + workspace_size = round_nearest(workspace_size, MinWorkspaceAlignment); + workspace_size += CollectiveMainloop::get_workspace_size(args.problem_shape, args.mainloop, sm_count); workspace_size = round_nearest(workspace_size, MinWorkspaceAlignment); @@ -363,6 +369,12 @@ class GemmUniversal< static_assert(size(TiledMma{}) == 256, "Cooperative kernel must have TiledMMA operating using 256 threads."); static_assert(size<0>(TileShape{}) >= 128, "Cooperative kernel requires Tile Size to be greater than or equal to 128 along the M-dimension."); + static_assert(NumMmaWarpGroups == 2, "Cooperative kernels currently only support NumMmaWarpGroups == 2"); + + if constexpr (cutlass::epilogue::collective::detail::sm90_is_ptr_array_tma_dispatch_policy_v) { + static_assert(NumMmaWarpGroups == CollectiveEpilogue::NumEpilogueWarpGroups, + "Tiled MmA does not match expected warp groups performing the epilogue"); + } static_assert(cute::rank(InternalStrideA{}) == 3, "StrideA must be rank-3: [M, K, L]. If batch mode is not needed, set L stride to Int<0>."); static_assert(cute::rank(InternalStrideB{}) == 3, "StrideB must be rank-3: [N, K, L]. 
If batch mode is not needed, set L stride to Int<0>."); @@ -391,7 +403,8 @@ class GemmUniversal< int warp_idx_in_warp_group = warp_idx % NumWarpsPerWarpGroup; int warp_group_thread_idx = thread_idx % NumThreadsPerWarpGroup; int mma_thread_idx = thread_idx % size(TiledMma{}); - auto warp_group_role = WarpGroupRole(canonical_warp_group_idx()); + auto warp_group_idx = canonical_warp_group_idx(); + auto warp_group_role = WarpGroupRole(warp_group_idx); auto producer_warp_role = ProducerWarpRole(warp_idx_in_warp_group); int lane_predicate = cute::elect_one_sync(); uint32_t block_rank_in_cluster = cute::block_rank_in_cluster(); @@ -466,7 +479,9 @@ class GemmUniversal< // Get the appropriate blocks for this thread block -- potential for thread block locality TiledMma tiled_mma; - auto blk_shape = TileShape{}; // (BLK_M,BLK_N,BLK_K) + const auto blk_shape = TileShape{}; // (BLK_M,BLK_N,BLK_K) + const auto c_tile_count = CollectiveEpilogue::get_load_pipe_increment(blk_shape); + const auto d_tile_count = CollectiveEpilogue::get_store_pipe_increment(blk_shape); TileScheduler scheduler{params.scheduler}; @@ -484,7 +499,7 @@ class GemmUniversal< } // Optionally append 1s until problem shape is rank-4 in case it is only rank-3 (MNK) - auto problem_shape_MNKL = append<4>(params.problem_shape.get_problem_shape(work_tile_info.L_idx), Int<1>{}); + auto problem_shape_MNKL = append<4>(params.problem_shape.get_problem_shape(work_tile_info.L_idx), work_tile_info.L_idx); // Prepare and partition the input tensors. Expects a tuple of tensors where: // get<0>(load_inputs) is the tma tensor A after local tiling so that it has shape (BLK_M,BLK_K,m,k,l) @@ -510,7 +525,7 @@ class GemmUniversal< int32_t const sm_count = params.hw_info.sm_count; // Fetch a copy of tensormaps for the CTA - auto input_tensormaps = collective_mainloop.tensormaps_init(params.mainloop, sm_count, sm_idx); + auto input_tensormaps = collective_mainloop.tensormaps_init(params.mainloop, shared_storage.tensormaps.mainloop, sm_count, sm_idx); // Update tensormap for the initial batch for the CTA if (work_tile_info.is_valid()) { @@ -578,7 +593,7 @@ class GemmUniversal< if (work_tile_info.is_valid() && did_batch_change) { curr_batch = next_batch; if constexpr (IsGroupedGemmKernel) { - problem_shape_MNKL = append<4>(params.problem_shape.get_problem_shape(curr_batch), Int<1>{}); + problem_shape_MNKL = append<4>(params.problem_shape.get_problem_shape(curr_batch), curr_batch); } // Purpose of this pipeline state is to make sure TMA loads have finished before doing descriptor updates // Since this state is waiting for loads to finish, it must start in the inverted phase. 
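The "inverted phase" comment above refers to the same pattern this patch introduces in the epilogue load() (the wait_until_load_finishes path earlier in this section): to make a producer warp block until its own last TMA load has landed, construct a consumer-side pipeline state that reuses the producer state's index and count but negates the phase, then call consumer_wait on it. A minimal sketch of that idea (not part of the diff), assuming a CUTLASS-style pipeline whose state type exposes index(), phase() and count() and is brace-constructible from those three values:

#include "cutlass/cutlass.h"   // CUTLASS_DEVICE

// LoadPipeline / LoadPipelineState stand in for the kernel's epilogue load
// pipeline types; this illustrates the phase trick, it is not kernel code.
template <class LoadPipeline, class LoadPipelineState>
CUTLASS_DEVICE void
wait_for_last_issued_tma_load(LoadPipeline& load_pipeline,
                              LoadPipelineState last_producer_state) {
  // Same stage index and count, opposite phase: consumer_wait on this state
  // returns only once the TMA transaction issued for that stage has been
  // committed, after which it is safe to modify the tensormap it used.
  LoadPipelineState tma_consumer_state{last_producer_state.index(),
                                       !last_producer_state.phase(),
                                       last_producer_state.count()};
  load_pipeline.consumer_wait(tma_consumer_state);
}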
@@ -610,7 +625,7 @@ class GemmUniversal< int32_t const sm_idx = blockIdx.x + (blockIdx.y * gridDim.x); int32_t const sm_count = params.hw_info.sm_count; - auto epi_load_tensormap = get<0>(collective_epilogue.load_init(params.epilogue, sm_count, sm_idx)); + auto epi_load_tensormap = get<0>(collective_epilogue.load_init(params.epilogue, shared_storage.tensormaps.epilogue, sm_count, sm_idx)); bool did_batch_change = true; constexpr bool IsEpiLoad = true; @@ -620,23 +635,26 @@ class GemmUniversal< shared_storage.tensormaps.epilogue, params.epilogue, epi_load_tensormap, - work_tile_info.L_idx + problem_shape_MNKL, + work_tile_info.L_idx, + 0 ); // Converge before issuing tensormap fence release since fence is aligned __syncwarp(); - collective_epilogue.tensormaps_cp_fence_release(shared_storage.tensormaps.epilogue, epi_load_tensormap, lane_predicate); + collective_epilogue.tensormaps_cp_fence_release(shared_storage.tensormaps.epilogue, epi_load_tensormap, lane_predicate, 0); } load_order_barrier.wait(); while (work_tile_info.is_valid()) { int32_t curr_batch = work_tile_info.L_idx; - bool compute_epilogue = TileScheduler::compute_epilogue(work_tile_info, params.scheduler); + // Get next work tile + auto next_work_tile_info = scheduler.fetch_next_work(work_tile_info); - if (compute_epilogue) { + if (TileScheduler::compute_epilogue(work_tile_info, params.scheduler)) { if constexpr (IsGroupedGemmKernel) { - problem_shape_MNKL = append<4>(params.problem_shape.get_problem_shape(work_tile_info.L_idx), Int<1>{}); + problem_shape_MNKL = append<4>(params.problem_shape.get_problem_shape(work_tile_info.L_idx), work_tile_info.L_idx); } // Compute m_coord, n_coord, l_coord with the post-tiled m-shape and n-shape @@ -649,6 +667,8 @@ class GemmUniversal< collective_epilogue.tensormaps_fence_acquire(epi_load_tensormap); } + bool wait = work_tile_info.is_valid() && curr_batch != next_work_tile_info.L_idx; + epi_load_pipe_producer_state = collective_epilogue.load( epi_load_pipeline, epi_load_pipe_producer_state, @@ -660,36 +680,34 @@ class GemmUniversal< shared_storage.tensors.epilogue, epi_load_tensormap, work_tile_info.reduction_subtile_idx(), - true // return state prior to last advance + wait ); } - // Get next work tile - work_tile_info = scheduler.fetch_next_work(work_tile_info); + work_tile_info = next_work_tile_info; did_batch_change = curr_batch != work_tile_info.L_idx; if (work_tile_info.is_valid() && did_batch_change) { - // Wait for TMA load to finish before updating - typename CollectiveEpilogue::LoadPipelineState epi_load_pipe_tma_consumer_state = - {epi_load_pipe_producer_state.index(), !epi_load_pipe_producer_state.phase(), epi_load_pipe_producer_state.count()}; - - epi_load_pipeline.consumer_wait(epi_load_pipe_tma_consumer_state); - - collective_epilogue.tensormaps_perform_update( - shared_storage.tensormaps.epilogue, - params.epilogue, - epi_load_tensormap, - work_tile_info.L_idx - ); - - // Converge before issuing tensormap fence release since fence is aligned - __syncwarp(); - collective_epilogue.tensormaps_cp_fence_release(shared_storage.tensormaps.epilogue, epi_load_tensormap, lane_predicate); - } + if constexpr (IsGroupedGemmKernel) { + problem_shape_MNKL = append<4>(params.problem_shape.get_problem_shape(work_tile_info.L_idx), work_tile_info.L_idx); + } - if(compute_epilogue) { - epi_load_pipe_producer_state.advance(1); + // tensormap update + { + collective_epilogue.tensormaps_perform_update( + shared_storage.tensormaps.epilogue, + params.epilogue, + epi_load_tensormap, + 
problem_shape_MNKL, + work_tile_info.L_idx, + 0 + ); + + // Converge before issuing tensormap fence release since fence is aligned + __syncwarp(); + collective_epilogue.tensormaps_cp_fence_release(shared_storage.tensormaps.epilogue, epi_load_tensormap, lane_predicate, 0); + } } } // Scheduler work fetch loop @@ -702,32 +720,43 @@ class GemmUniversal< else if (warp_group_role == WarpGroupRole::Consumer0 || warp_group_role == WarpGroupRole::Consumer1) { cutlass::arch::warpgroup_reg_alloc(); + // Index of warp group within consumer warp groups + int consumer_warp_group_idx = warp_group_role == WarpGroupRole::Consumer0 ? 0 : 1; + int32_t const sm_idx = blockIdx.x + (blockIdx.y * gridDim.x); int32_t const sm_count = params.hw_info.sm_count; // Do we potentially issue tail arrives for TMA stores, if epilogue load is waiting for it bool do_store_tail = false; // Get a copy of tensormaps - auto epi_store_tensormap = get<0>(collective_epilogue.store_init(params.epilogue, sm_count, sm_idx)); + auto epi_store_tensormap = get<0>(collective_epilogue.store_init(params.epilogue, shared_storage.tensormaps.epilogue, sm_count, sm_idx, consumer_warp_group_idx)); bool did_batch_change = true; constexpr bool IsEpiLoad = false; if (work_tile_info.is_valid()) { - collective_epilogue.tensormaps_perform_update( - shared_storage.tensormaps.epilogue, - params.epilogue, - epi_store_tensormap, - work_tile_info.L_idx - ); + if (warp_idx_in_warp_group == 0) { + + collective_epilogue.tensormaps_perform_update( + shared_storage.tensormaps.epilogue, + params.epilogue, + epi_store_tensormap, + problem_shape_MNKL, + work_tile_info.L_idx, + consumer_warp_group_idx + ); - // Converge before issuing tensormap fence release since fence is aligned - __syncwarp(); - collective_epilogue.tensormaps_cp_fence_release(shared_storage.tensormaps.epilogue, epi_store_tensormap, lane_predicate); + // Converge before issuing tensormap fence release since fence is aligned + __syncwarp(); + collective_epilogue.tensormaps_cp_fence_release(shared_storage.tensormaps.epilogue, + epi_store_tensormap, + lane_predicate, + consumer_warp_group_idx); + } } while (work_tile_info.is_valid()) { if constexpr (IsGroupedGemmKernel) { - problem_shape_MNKL = append<4>(params.problem_shape.get_problem_shape(work_tile_info.L_idx), Int<1>{}); + problem_shape_MNKL = append<4>(params.problem_shape.get_problem_shape(work_tile_info.L_idx), work_tile_info.L_idx); } int32_t curr_batch = work_tile_info.L_idx; @@ -743,6 +772,10 @@ class GemmUniversal< // // MSVC CTAD breaks if we say "Tensor" here, so we use "auto" instead. 
auto accumulators = partition_fragment_C(tiled_mma, take<0,2>(blk_shape)); // (MMA,MMA_M,MMA_N) + + static_assert(cute::is_any_of_v, + detail::PersistentTileSchedulerSm90>); if(TileScheduler::valid_warpgroup_in_work_tile(work_tile_info)) { collective_mainloop.mma( mainloop_pipeline, @@ -764,18 +797,16 @@ class GemmUniversal< // Update starting mainloop pipeline state for the next tile mainloop_pipe_consumer_state.advance(work_k_tile_count); } - // Index of warp group within consumer warp groups - int consumer_warp_group_idx = canonical_warp_group_idx() - NumLoadWarpGroups; // Perform reduction across splits, if needed TileScheduler::fixup( params.scheduler, work_tile_info, accumulators, NumMmaWarpGroups, consumer_warp_group_idx); - if (TileScheduler::compute_epilogue(work_tile_info, params.scheduler)) { + if (did_batch_change) { + collective_epilogue.tensormaps_fence_acquire(epi_store_tensormap); + } - if (did_batch_change) { - collective_epilogue.tensormaps_fence_acquire(epi_store_tensormap); - } + if (TileScheduler::compute_epilogue(work_tile_info, params.scheduler)) { // Epilogue and write to gD auto [epi_load_pipe_consumer_state_next, epi_store_pipe_producer_state_next] = @@ -804,20 +835,31 @@ class GemmUniversal< did_batch_change = curr_batch != work_tile_info.L_idx; if (work_tile_info.is_valid() && did_batch_change) { - collective_epilogue.tensormaps_perform_update( - shared_storage.tensormaps.epilogue, - params.epilogue, - epi_store_tensormap, - work_tile_info.L_idx - ); + if constexpr (IsGroupedGemmKernel) { + problem_shape_MNKL = append<4>(params.problem_shape.get_problem_shape(work_tile_info.L_idx), work_tile_info.L_idx); + } + if (warp_idx_in_warp_group == 0) { + collective_epilogue.tensormaps_perform_update( + shared_storage.tensormaps.epilogue, + params.epilogue, + epi_store_tensormap, + problem_shape_MNKL, + work_tile_info.L_idx, + consumer_warp_group_idx + ); - // Converge before issuing tensormap fence release since fence is aligned - __syncwarp(); - collective_epilogue.tensormaps_cp_fence_release(shared_storage.tensormaps.epilogue, epi_store_tensormap, lane_predicate); + // Converge before issuing tensormap fence release since fence is aligned + __syncwarp(); + collective_epilogue.tensormaps_cp_fence_release(shared_storage.tensormaps.epilogue, + epi_store_tensormap, + lane_predicate, + consumer_warp_group_idx); + } } } // Scheduler work fetch loop + // Cooperative only needs TMA to complete at the very end of the kernel if (do_store_tail) { collective_epilogue.store_tail( epi_load_pipeline, @@ -829,7 +871,6 @@ class GemmUniversal< } // Consumer Warp Groups End #endif } - }; /////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/gemm/kernel/sm90_gemm_array_tma_warpspecialized_pingpong.hpp b/include/cutlass/gemm/kernel/sm90_gemm_array_tma_warpspecialized_pingpong.hpp new file mode 100644 index 0000000000..491ec0ec6c --- /dev/null +++ b/include/cutlass/gemm/kernel/sm90_gemm_array_tma_warpspecialized_pingpong.hpp @@ -0,0 +1,949 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. 
Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/workspace.h" +#include "cutlass/fast_math.h" +#include "cutlass/kernel_hardware_info.hpp" +#include "cute/arch/cluster_sm90.hpp" +#include "cutlass/arch/reg_reconfig.h" +#include "cutlass/arch/mma_sm90.h" +#include "cutlass/epilogue/collective/detail.hpp" +#include "cutlass/gemm/gemm.h" +#include "cutlass/gemm/dispatch_policy.hpp" +#include "cutlass/gemm/kernel/gemm_universal_decl.h" +#include "cutlass/gemm/kernel/tile_scheduler.hpp" +#include "cutlass/gemm/group_array_problem_shape.hpp" +#include "cutlass/pipeline/pipeline.hpp" +#include "cute/tensor.hpp" +#include "cutlass/trace.h" +#include "cutlass/gemm/kernel/sm90_tile_scheduler.hpp" +#include "cutlass/gemm/kernel/sm90_tile_scheduler_group.hpp" + +/////////////////////////////////////////////////////////////////////////////// + +namespace cutlass::gemm::kernel { + +/////////////////////////////////////////////////////////////////////////////// + +template < + class ProblemShape_, + class CollectiveMainloop_, + class CollectiveEpilogue_, + class TileScheduler_ +> +class GemmUniversal< + ProblemShape_, + CollectiveMainloop_, + CollectiveEpilogue_, + TileScheduler_, + cute::enable_if_t> +> +{ +public: + // + // Type Aliases + // + using ProblemShape = ProblemShape_; + static_assert(rank(typename ProblemShape::UnderlyingProblemShape{}) == 3 or rank(typename ProblemShape::UnderlyingProblemShape{}) == 4, + "ProblemShape{} should be or "); + + static_assert(cute::is_base_of_v); + + static constexpr bool IsGdcEnabled = false; + + // Mainloop derived types + using CollectiveMainloop = CollectiveMainloop_; + using TileShape = typename CollectiveMainloop::TileShape; + using TiledMma = typename CollectiveMainloop::TiledMma; + using ArchTag = typename CollectiveMainloop::ArchTag; + using ElementA = typename CollectiveMainloop::ElementA; + using StrideA = typename CollectiveMainloop::StrideA; + using InternalStrideA = typename CollectiveMainloop::InternalStrideA; + using ElementB = typename CollectiveMainloop::ElementB; + using InternalStrideB = 
typename CollectiveMainloop::InternalStrideB; + using StrideB = typename CollectiveMainloop::StrideB; + using DispatchPolicy = typename CollectiveMainloop::DispatchPolicy; + using Schedule = typename DispatchPolicy::Schedule; + using ElementAccumulator = typename CollectiveMainloop::ElementAccumulator; + using ClusterShape = typename DispatchPolicy::ClusterShape; + using MainloopArguments = typename CollectiveMainloop::Arguments; + using MainloopParams = typename CollectiveMainloop::Params; + + // Epilogue derived types + using CollectiveEpilogue = CollectiveEpilogue_; + using ElementC = typename CollectiveEpilogue::ElementC; + using StrideC = typename CollectiveEpilogue::StrideC; + using InternalStrideC = typename CollectiveEpilogue::InternalStrideC; + using ElementD = typename CollectiveEpilogue::ElementD; + using StrideD = typename CollectiveEpilogue::StrideD; + using InternalStrideD = typename CollectiveEpilogue::InternalStrideD; + using EpilogueArguments = typename CollectiveEpilogue::Arguments; + using EpilogueParams = typename CollectiveEpilogue::Params; + + static_assert(ArchTag::kMinComputeCapability >= 90); + static_assert(cute::is_void_v, + "Ptr-Array Pingpong and Grouped Gemm Pingpong kernel only supports the default scheduler."); + + static constexpr bool IsGroupedGemmKernel = !cute::is_same_v; + + using TileScheduler = cute::conditional_t::Scheduler, + typename detail::TileSchedulerSelector< + void, ArchTag, TileShape, ClusterShape>::Scheduler>; + using TileSchedulerArguments = typename TileScheduler::Arguments; + using TileSchedulerParams = typename TileScheduler::Params; + + static constexpr uint32_t NumLoadWarpGroups = 1; + static constexpr uint32_t NumMmaWarpGroups = 2; + static constexpr uint32_t MaxThreadsPerBlock = CUTE_STATIC_V(size(TiledMma{})) + (NumMmaWarpGroups * NumThreadsPerWarpGroup); + static constexpr uint32_t MinBlocksPerMultiprocessor = 1; + + /// Register requirement for Load and Math WGs + static constexpr uint32_t LoadRegisterRequirement = 40; + static constexpr uint32_t MmaRegisterRequirement = 232; + + // 1 stage ordered sequence between mainloop and epilogue producer load threads + using LoadWarpOrderBarrier = cutlass::OrderedSequenceBarrier<1,2>; + + // Order Sequence barrier with two stages: one for Mainloop and one for Epilogue + static constexpr uint32_t StagesPerMathWarpGroup = 2; + + using MathWarpGroupOrderBarrier = cutlass::OrderedSequenceBarrier; + + using MathWarpGroupOrderBarrierSharedStorage = cutlass::PipelineDetail::OrderedSequenceBarrierSharedStorage< + MathWarpGroupOrderBarrier::SequenceDepth, + MathWarpGroupOrderBarrier::SequenceLength>; + + // Kernel level shared memory storage + struct SharedStorage { + struct TensorStorage : cute::aligned_struct<128> { + using MainloopTensorStorage = typename CollectiveMainloop::TensorStorage; + using EpilogueTensorStorage = typename CollectiveEpilogue::TensorStorage; + + MainloopTensorStorage mainloop; + EpilogueTensorStorage epilogue; + } tensors; + + struct PipelineStorage : cute::aligned_struct<16> { + using MainloopPipelineStorage = typename CollectiveMainloop::PipelineStorage; + using EpiLoadPipelineStorage = typename CollectiveEpilogue::PipelineStorage; + using MathWarpGroupOrderBarrierStorage = MathWarpGroupOrderBarrierSharedStorage; + + alignas(16) MainloopPipelineStorage mainloop; + alignas(16) EpiLoadPipelineStorage epi_load; + alignas(16) typename LoadWarpOrderBarrier::SharedStorage load_order; + alignas(16) MathWarpGroupOrderBarrierStorage math_wg_order; + } pipelines; + + struct 
TensorMapStorage : cute::aligned_struct<128> { + using MainloopTensorMapStorage = typename CollectiveMainloop::TensorMapStorage; + using EpilogueTensorMapStorage = typename CollectiveEpilogue::TensorMapStorage; + + alignas(128) MainloopTensorMapStorage mainloop; + alignas(128) EpilogueTensorMapStorage epilogue; + } tensormaps; + }; + + static constexpr int SharedStorageSize = sizeof(SharedStorage); + + // Device side arguments + struct Arguments { + GemmUniversalMode mode{}; + ProblemShape problem_shape{}; + MainloopArguments mainloop{}; + EpilogueArguments epilogue{}; + KernelHardwareInfo hw_info{}; + TileSchedulerArguments scheduler{}; + }; + + // Kernel entry point API + struct Params { + GemmUniversalMode mode{}; + ProblemShape problem_shape{}; + MainloopParams mainloop{}; + EpilogueParams epilogue{}; + KernelHardwareInfo hw_info{}; + TileSchedulerParams scheduler{}; + void* workspace{nullptr}; + }; + + // + // Methods + // + + // Convert to underlying arguments. In this case, a simple copy for the aliased type. + static + Params + to_underlying_arguments(Arguments const& args, void* workspace) { + CUTLASS_TRACE_HOST("to_underlying_arguments():"); + + ProblemShape problem_shapes = args.problem_shape; + + // Get SM count if needed, otherwise use user supplied SM count + int sm_count = args.hw_info.sm_count; + if (sm_count <= 0) { + CUTLASS_TRACE_HOST(" WARNING: Arguments do not include a valid SM count.\n" + " For optimal performance, populate the arguments KernelHardwareInfo struct with the SM count."); + sm_count = KernelHardwareInfo::query_device_multiprocessor_count(args.hw_info.device_id); + } + + CUTLASS_TRACE_HOST("to_underlying_arguments(): Setting persistent grid SM count to " << sm_count); + + KernelHardwareInfo hw_info{args.hw_info.device_id, sm_count}; + + // Calculate workspace pointers + uint8_t* workspace_ptr = reinterpret_cast(workspace); + size_t workspace_offset = 0; + + void* scheduler_workspace = workspace_ptr; + workspace_offset += TileScheduler::template get_workspace_size( + args.scheduler, typename ProblemShape::UnderlyingProblemShape{}, args.hw_info, NumMmaWarpGroups); + workspace_offset = round_nearest(workspace_offset, MinWorkspaceAlignment); + + void* epilogue_workspace = workspace_ptr + workspace_offset; + workspace_offset += CollectiveEpilogue::get_workspace_size(problem_shapes, args.epilogue, sm_count); + workspace_offset = round_nearest(workspace_offset, MinWorkspaceAlignment); + + void* mainloop_workspace = workspace_ptr + workspace_offset; + workspace_offset += CollectiveMainloop::get_workspace_size(problem_shapes, args.mainloop, sm_count); + workspace_offset = round_nearest(workspace_offset, MinWorkspaceAlignment); + + // Precompute the sub tiles numbers in epilogue, pass into tile scheduler. Therefore it will be used + // in separate reduction scheme for streamk case, NumEpilogueSubTiles default value is 1, which means + // subtile will not be used, therefore separate reduction will not be enabled. 
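  // get_store_pipe_increment(TileShape{}) counts the epilogue sub-tiles per CTA output tile;
  // for instance, a 128x128 CTA tile stored through a 64x32 epilogue tile would give
  // (128/64) * (128/32) = 8 sub-tiles (the exact epilogue tile is chosen by the collective builder).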
+ constexpr uint32_t NumEpilogueSubTiles = CollectiveEpilogue::get_store_pipe_increment(TileShape{}); + TileSchedulerParams scheduler; + if constexpr (IsGroupedGemmKernel) { + scheduler = TileScheduler::to_underlying_arguments( + problem_shapes, TileShape{}, ClusterShape{}, hw_info, args.scheduler, scheduler_workspace, NumEpilogueSubTiles); + } + else { + scheduler = TileScheduler::to_underlying_arguments( + problem_shapes.get_host_problem_shape(), TileShape{}, ClusterShape{}, hw_info, args.scheduler, scheduler_workspace, NumEpilogueSubTiles); + } + + return { + args.mode, + problem_shapes, + CollectiveMainloop::to_underlying_arguments(problem_shapes, args.mainloop, mainloop_workspace), + CollectiveEpilogue::to_underlying_arguments(problem_shapes, args.epilogue, epilogue_workspace), + hw_info, + scheduler, + workspace + }; + } + + static bool + can_implement(Arguments const& args) { + bool implementable = true; + if constexpr (IsGroupedGemmKernel) { + // Group GEMM currently only supports rank-3 problem shapes + implementable &= (args.mode == GemmUniversalMode::kGrouped && rank(typename ProblemShape::UnderlyingProblemShape{}) == 3); + } else { + implementable &= (args.mode == GemmUniversalMode::kArray && rank(typename ProblemShape::UnderlyingProblemShape{}) == 4); + } + if (!implementable) { + CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Arguments or Problem Shape don't meet the requirements for Ptr Array Gemm or Grouped Gemm.\n"); + return implementable; + } + implementable &= CollectiveMainloop::can_implement(args.problem_shape, args.mainloop); + implementable &= CollectiveEpilogue::can_implement(args.problem_shape, args.epilogue); + implementable &= TileScheduler::can_implement(args.scheduler); + return implementable; + } + + static size_t + get_workspace_size(Arguments const& args) { + size_t workspace_size = 0; + constexpr uint32_t NumEpilogueSubTiles = CollectiveEpilogue::get_store_pipe_increment(TileShape{}); + + workspace_size += TileScheduler::template get_workspace_size( + args.scheduler, typename ProblemShape::UnderlyingProblemShape{}, args.hw_info, NumMmaWarpGroups, NumEpilogueSubTiles); + workspace_size = round_nearest(workspace_size, MinWorkspaceAlignment); + + // Get SM count if needed, otherwise use user supplied SM count + int sm_count = args.hw_info.sm_count; + if (sm_count <= 0) { + CUTLASS_TRACE_HOST(" WARNING: Arguments do not include a valid SM count.\n" + " For optimal performance, populate the arguments KernelHardwareInfo struct with the SM count."); + sm_count = KernelHardwareInfo::query_device_multiprocessor_count(args.hw_info.device_id); + } + + workspace_size += CollectiveEpilogue::get_workspace_size(args.problem_shape, args.epilogue, sm_count); + workspace_size = round_nearest(workspace_size, MinWorkspaceAlignment); + + workspace_size += CollectiveMainloop::get_workspace_size(args.problem_shape, args.mainloop, sm_count); + workspace_size = round_nearest(workspace_size, MinWorkspaceAlignment); + + return workspace_size; + } + + static cutlass::Status + initialize_workspace(Arguments const& args, void* workspace = nullptr, cudaStream_t stream = nullptr, + CudaHostAdapter* cuda_adapter = nullptr) { + Status status = Status::kSuccess; + uint8_t* workspace_ptr = reinterpret_cast(workspace); + size_t workspace_offset = 0; + constexpr uint32_t NumEpilogueSubTiles = CollectiveEpilogue::get_store_pipe_increment(TileShape{}); + + status = TileScheduler::template initialize_workspace( + args.scheduler, workspace_ptr + workspace_offset, stream, typename 
ProblemShape::UnderlyingProblemShape{}, args.hw_info, NumMmaWarpGroups, NumEpilogueSubTiles, cuda_adapter); + workspace_offset += TileScheduler::template get_workspace_size( + args.scheduler, typename ProblemShape::UnderlyingProblemShape{}, args.hw_info, NumMmaWarpGroups, NumEpilogueSubTiles); + workspace_offset = round_nearest(workspace_offset, MinWorkspaceAlignment); + if (status != Status::kSuccess) { + return status; + } + + status = CollectiveEpilogue::initialize_workspace(args.problem_shape, args.epilogue, workspace_ptr + workspace_offset, stream, cuda_adapter); + workspace_offset += CollectiveEpilogue::get_workspace_size(args.problem_shape, args.epilogue, args.hw_info.sm_count); + workspace_offset = round_nearest(workspace_offset, MinWorkspaceAlignment); + + status = CollectiveMainloop::initialize_workspace(args.problem_shape, args.mainloop, workspace_ptr + workspace_offset, stream, cuda_adapter); + workspace_offset += CollectiveMainloop::get_workspace_size(args.problem_shape, args.mainloop, args.hw_info.sm_count); + workspace_offset = round_nearest(workspace_offset, MinWorkspaceAlignment); + + if (status != Status::kSuccess) { + return status; + } + + return status; + } + + // Computes the kernel launch grid shape based on runtime parameters + static dim3 + get_grid_shape(Params const& params) { + // Given device SM count, set grid size s.t. we do not launch more thread blocks than we can run concurrently + TileSchedulerArguments args{}; + if constexpr (!std::is_const_v) { + args.max_swizzle_size = 1 << params.scheduler.log_swizzle_size_; + } + args.raster_order = params.scheduler.raster_order_ == TileScheduler::RasterOrder::AlongN ? TileScheduler::RasterOrderOptions::AlongN : TileScheduler::RasterOrderOptions::AlongM; + dim3 grid_shape; + if constexpr (IsGroupedGemmKernel) { + grid_shape = TileScheduler::get_grid_shape(params.problem_shape, TileShape{}, ClusterShape{}, params.hw_info, args); + } + else { + grid_shape = TileScheduler::get_grid_shape(params.problem_shape.get_host_problem_shape(), TileShape{}, ClusterShape{}, params.hw_info, args); + } + return grid_shape; + } + + static dim3 + get_block_shape() { + return dim3(MaxThreadsPerBlock, 1, 1); + } + + CUTLASS_DEVICE + void + operator()(Params const& params, char* smem_buf) { + using namespace cute; + using X = Underscore; + +// Any Tensor Op MMA Atom in the WGMMA ISA is arch conditional to sm90a. +#if ! defined(__CUDA_ARCH_FEAT_SM90_ALL) + printf("ERROR : Arch conditional MMA instruction used without targeting sm90a compute capability. Aborting.\n"); +#else + + // Preconditions + static_assert(size(TiledMma{}) == 128, "Pingpong kernel must have TiledMMA operating using 128 threads."); + static_assert(NumMmaWarpGroups == 2, "Pingpong kernels currently only support NumMmaWarpGroups == 2"); + + if constexpr (cutlass::epilogue::collective::detail::sm90_is_ptr_array_tma_dispatch_policy_v) { + static_assert(NumMmaWarpGroups == CollectiveEpilogue::NumEpilogueWarpGroups, + "Tiled MmA does not match expected warp groups performing the epilogue"); + } + + static_assert(cute::rank(InternalStrideA{}) == 3, "StrideA must be rank-3: [M, K, L]. If batch mode is not needed, set L stride to Int<0>."); + static_assert(cute::rank(InternalStrideB{}) == 3, "StrideB must be rank-3: [N, K, L]. If batch mode is not needed, set L stride to Int<0>."); + static_assert(cute::rank(InternalStrideC{}) == 3, "StrideC must be rank-3: [M, N, L]. 
If batch mode is not needed, set L stride to Int<0>."); + static_assert(cute::rank(InternalStrideD{}) == 3, "StrideD must be rank-3: [M, N, L]. If batch mode is not needed, set L stride to Int<0>."); + + enum class WarpGroupRole { + Producer = 0, + Consumer0 = 1, + Consumer1 = 2 + }; + enum class ProducerWarpRole { + Mainloop = 0, + Warp1 = 1, + Epilogue = 2, + Warp3 = 3 + }; + + // Kernel level shared memory storage + SharedStorage& shared_storage = *reinterpret_cast(smem_buf); + + int thread_idx = int(threadIdx.x); + int lane_idx = canonical_lane_idx(); + int warp_idx = canonical_warp_idx_sync(); + int warp_idx_in_warp_group = warp_idx % NumWarpsPerWarpGroup; + int warp_group_thread_idx = thread_idx % NumThreadsPerWarpGroup; + int mma_thread_idx = thread_idx % size(TiledMma{}); + auto warp_group_idx = canonical_warp_group_idx(); + auto warp_group_role = WarpGroupRole(warp_group_idx); + auto producer_warp_role = ProducerWarpRole(warp_idx_in_warp_group); + int lane_predicate = cute::elect_one_sync(); + uint32_t block_rank_in_cluster = cute::block_rank_in_cluster(); + + // Note: Tma Descriptor Prefetch (from either const or param) is not applicable here + + // Mainloop Load pipeline + using MainloopPipeline = typename CollectiveMainloop::MainloopPipeline; + typename MainloopPipeline::Params mainloop_pipeline_params; + if (warp_group_role == WarpGroupRole::Producer && producer_warp_role == ProducerWarpRole::Mainloop) { + mainloop_pipeline_params.role = MainloopPipeline::ThreadCategory::Producer; + } + if (warp_group_role == WarpGroupRole::Consumer0 || warp_group_role == WarpGroupRole::Consumer1) { + mainloop_pipeline_params.role = MainloopPipeline::ThreadCategory::Consumer; + } + mainloop_pipeline_params.is_leader = warp_group_thread_idx == 0; + mainloop_pipeline_params.num_consumers = NumThreadsPerWarpGroup; + mainloop_pipeline_params.transaction_bytes = params.mainloop.tma_transaction_bytes; + MainloopPipeline mainloop_pipeline(shared_storage.pipelines.mainloop, mainloop_pipeline_params, ClusterShape{}); + + // Epilogue Load pipeline + using EpiLoadPipeline = typename CollectiveEpilogue::LoadPipeline; + typename EpiLoadPipeline::Params epi_load_pipeline_params; + if (warp_group_role == WarpGroupRole::Producer && producer_warp_role == ProducerWarpRole::Epilogue) { + epi_load_pipeline_params.role = EpiLoadPipeline::ThreadCategory::Producer; + } + if (warp_group_role == WarpGroupRole::Consumer0 || warp_group_role == WarpGroupRole::Consumer1) { + epi_load_pipeline_params.role = EpiLoadPipeline::ThreadCategory::Consumer; + } + epi_load_pipeline_params.dst_blockid = cute::block_rank_in_cluster(); + epi_load_pipeline_params.producer_arv_count = NumThreadsPerWarp; + epi_load_pipeline_params.consumer_arv_count = NumThreadsPerWarpGroup; + if constexpr (CollectiveEpilogue::RequiresTransactionBytes) { + epi_load_pipeline_params.transaction_bytes = params.epilogue.tma_transaction_bytes; + } + EpiLoadPipeline epi_load_pipeline(shared_storage.pipelines.epi_load, epi_load_pipeline_params); + + // Epilogue Store pipeline + using EpiStorePipeline = typename CollectiveEpilogue::StorePipeline; + typename EpiStorePipeline::Params epi_store_pipeline_params; + epi_store_pipeline_params.always_wait = true; + EpiStorePipeline epi_store_pipeline(epi_store_pipeline_params); + + typename LoadWarpOrderBarrier::Params params_load_order_barrier; + params_load_order_barrier.group_id = producer_warp_role == ProducerWarpRole::Mainloop ? 
0 : 1; + params_load_order_barrier.group_size = NumThreadsPerWarp; + LoadWarpOrderBarrier load_order_barrier(shared_storage.pipelines.load_order, params_load_order_barrier); + + typename MathWarpGroupOrderBarrier::Params params_math_wg_order_barrier; + // DMA Load WG will not participate in these Ordered Barrier syncs + params_math_wg_order_barrier.group_id = warp_group_idx - static_cast(WarpGroupRole::Consumer0); + params_math_wg_order_barrier.group_size = NumThreadsPerWarpGroup; // Number of threads / participants in a group + MathWarpGroupOrderBarrier math_wg_order_barrier(shared_storage.pipelines.math_wg_order, params_math_wg_order_barrier); + + // Initialize starting pipeline states for the collectives + // Epilogue store pipe is producer-only (consumer is TMA unit, waits via scoreboarding) + typename CollectiveMainloop::PipelineState mainloop_pipe_consumer_state; + typename CollectiveEpilogue::LoadPipelineState epi_load_pipe_consumer_state; + + // For the DMA Load (producer) we start with an opposite phase + // i.e., we skip all waits since we know that the buffer is indeed empty + PipelineState mainloop_pipe_producer_state = cutlass::make_producer_start_state(); + PipelineState epi_load_pipe_producer_state = cutlass::make_producer_start_state(); + PipelineState epi_store_pipe_producer_state = cutlass::make_producer_start_state(); + + auto cluster_wait_fn = [] () { + // We need this to guarantee that the Pipeline init is visible + // To all producers and consumer thread blocks in the Cluster + if constexpr (size(ClusterShape{}) > 1) { + cute::cluster_arrive_relaxed(); + return [] () { cute::cluster_wait(); }; + } + else { + __syncthreads(); + return [] () {}; // do nothing + } + } (); + + // Get the appropriate blocks for this thread block -- potential for thread block locality + TiledMma tiled_mma; + const auto blk_shape = TileShape{}; // (BLK_M,BLK_N,BLK_K) + const auto c_tile_count = CollectiveEpilogue::get_load_pipe_increment(blk_shape); + const auto d_tile_count = CollectiveEpilogue::get_store_pipe_increment(blk_shape); + + TileScheduler scheduler{params.scheduler}; + + // In a warp specialized kernel, collectives expose data movement and compute operations separately + CollectiveMainloop collective_mainloop; + CollectiveEpilogue collective_epilogue(params.epilogue, shared_storage.tensors.epilogue); + + // Wait for all thread blocks in the Cluster + cluster_wait_fn(); + + auto work_tile_info = scheduler.initial_work_tile_info(ClusterShape{}); + if (not work_tile_info.is_valid()) { + // When problem shapes are only on device, the grid launched may be larger than the total number of blocks across groups + return; + } + + // Optionally append 1s until problem shape is rank-4 in case it is only rank-3 (MNK) + auto problem_shape_MNKL = append<4>(params.problem_shape.get_problem_shape(work_tile_info.L_idx), work_tile_info.L_idx); + + if (warp_group_role == WarpGroupRole::Consumer1) { + // Advance 2nd Math WG to the next work tile for the startup + const auto k_tile_count = TileScheduler::get_work_k_tile_count(work_tile_info, problem_shape_MNKL, blk_shape); + + auto next_work_tile_info = scheduler.fetch_next_work(work_tile_info); + work_tile_info = next_work_tile_info; + if (!work_tile_info.is_valid()) { + return; + } + + // Advance 2nd Math WG pipeline states to the end of 1st Math WG + mainloop_pipe_consumer_state.advance(k_tile_count); + epi_load_pipe_consumer_state.advance(c_tile_count); + epi_store_pipe_producer_state.advance(d_tile_count); + + problem_shape_MNKL = 
append<4>(params.problem_shape.get_problem_shape(work_tile_info.L_idx), work_tile_info.L_idx); + } + + // Prepare and partition the input tensors. Expects a tuple of tensors where: + // get<0>(load_inputs) is the tma tensor A after local tiling so that it has shape (BLK_M,BLK_K,m,k,l) + // get<1>(load_inputs) is the tma tensor B after local tiling so that it has shape (BLK_N,BLK_K,n,k,l) + auto load_inputs = collective_mainloop.load_init(problem_shape_MNKL, params.mainloop); + static_assert(cute::tuple_size_v >= 2, "Output of load_init must have at least two elements (A, B)"); + + // Extract out partitioned A and B. + Tensor gA_mkl = get<0>(load_inputs); + Tensor gB_nkl = get<1>(load_inputs); + + // Get pipeline stage increments from tensor shapes + auto k_tile_count = size<3>(gA_mkl); + + if (warp_group_role == WarpGroupRole::Producer) { + cutlass::arch::warpgroup_reg_dealloc(); + + // Mainloop Producer Warp + if (producer_warp_role == ProducerWarpRole::Mainloop) { + int32_t curr_batch = idx2crd(work_tile_info.L_idx, shape<4>(gB_nkl)); // Usually just returns work_tile_info.L_idx; + int32_t const mock_l_coord = 0; + int32_t const sm_idx = blockIdx.x + (blockIdx.y * gridDim.x); + int32_t const sm_count = params.hw_info.sm_count; + + // Fetch a copy of tensormaps for the CTA + auto input_tensormaps = collective_mainloop.tensormaps_init(params.mainloop, shared_storage.tensormaps.mainloop, sm_count, sm_idx); + + // Update tensormap for the initial batch for the CTA + if (work_tile_info.is_valid()) { + collective_mainloop.tensormaps_perform_update( + shared_storage.tensormaps.mainloop, + params.mainloop, + input_tensormaps, + problem_shape_MNKL, + curr_batch + ); + // Ensure warp is converged before issuing tensormap fence release + __syncwarp(); + // Entire warp must do this (i.e. it's aligned) + collective_mainloop.tensormaps_cp_fence_release(shared_storage.tensormaps.mainloop, input_tensormaps); + } + + bool do_load_order_arrive = true; + bool did_batch_change = true; + while (work_tile_info.is_valid()) { + if (!TileScheduler::valid_warpgroup_in_work_tile(work_tile_info)) { + auto next_work_tile_info = scheduler.fetch_next_work(work_tile_info); + work_tile_info = next_work_tile_info; + continue; + } + + // Compute m_coord, n_coord, l_coord with the post-tiled m-shape and n-shape + auto m_coord = idx2crd(work_tile_info.M_idx, shape<2>(gA_mkl)); + auto n_coord = idx2crd(work_tile_info.N_idx, shape<2>(gB_nkl)); + auto blk_coord = make_coord(m_coord, n_coord, _, mock_l_coord); + + // Get the number of K tiles to compute for this work as well as the starting K tile offset of the work. 
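        // (In general a work unit may cover only a slice of the K iterations, which is why the
        // scheduler reports both a K-tile count and a starting K tile; with the persistent
        // schedulers used by this kernel the slice spans the full K range.)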
+ auto work_k_tile_count = TileScheduler::get_work_k_tile_count(work_tile_info, problem_shape_MNKL, blk_shape); + auto work_k_tile_start = TileScheduler::get_work_k_tile_start(work_tile_info); + auto k_tile_iter = cute::make_coord_iterator(idx2crd(work_k_tile_start, shape<3>(gA_mkl)), shape<3>(gA_mkl)); + + if (did_batch_change) { + collective_mainloop.tensormaps_fence_acquire(input_tensormaps); + } + + collective_mainloop.load( + params.mainloop, + mainloop_pipeline, + mainloop_pipe_producer_state, + load_inputs, + input_tensormaps, + blk_coord, + k_tile_iter, work_k_tile_count, + lane_idx, + block_rank_in_cluster, + shared_storage.tensors.mainloop + ); + // Update starting pipeline state for the next tile + // Wait for the last TMA stage to complete loading, before issuing tensormap updates + mainloop_pipe_producer_state.advance(work_k_tile_count - 1); + + // Signal for the epilogue load warp to begin + if (do_load_order_arrive) { + load_order_barrier.arrive(); + do_load_order_arrive = false; + } + + // Get next work tile + auto next_work_tile_info = scheduler.fetch_next_work(work_tile_info); + work_tile_info = next_work_tile_info; + auto next_batch = idx2crd(work_tile_info.L_idx, shape<4>(gB_nkl)); // Usually just returns work_tile_info.L_idx + did_batch_change = next_batch != curr_batch; + if (work_tile_info.is_valid() && did_batch_change) { + curr_batch = next_batch; + if constexpr (IsGroupedGemmKernel) { + problem_shape_MNKL = append<4>(params.problem_shape.get_problem_shape(curr_batch), curr_batch); + } + // Purpose of this pipeline state is to make sure TMA loads have finished before doing descriptor updates + // Since this state is waiting for loads to finish, it must start in the inverted phase. + typename CollectiveMainloop::PipelineState mainloop_pipe_tma_consumer_state = + {mainloop_pipe_producer_state.index(), !mainloop_pipe_producer_state.phase(), mainloop_pipe_producer_state.count()}; + mainloop_pipeline.consumer_wait(mainloop_pipe_tma_consumer_state); + collective_mainloop.tensormaps_perform_update( + shared_storage.tensormaps.mainloop, + params.mainloop, + input_tensormaps, + problem_shape_MNKL, + curr_batch + ); + // Ensure warp is converged before issuing tensor replace + __syncwarp(); + // Entire warp must do this (i.e. 
it's aligned) + collective_mainloop.tensormaps_cp_fence_release(shared_storage.tensormaps.mainloop, input_tensormaps); + } + // Advance the producer state for the last remaining stage that was being waited for above + mainloop_pipe_producer_state.advance(1); + } // Scheduler work fetch loop + + // Make sure all Consumer Warp Groups have been waited upon + collective_mainloop.load_tail(mainloop_pipeline, mainloop_pipe_producer_state); + } // Mainloop Producer Warp End + + // Epilogue Producer Warp + else if (producer_warp_role == ProducerWarpRole::Epilogue && collective_epilogue.is_producer_load_needed()) { + int32_t const sm_idx = blockIdx.x + (blockIdx.y * gridDim.x); + int32_t const sm_count = params.hw_info.sm_count; + + auto epi_load_tensormap = get<0>(collective_epilogue.load_init(params.epilogue, shared_storage.tensormaps.epilogue, sm_count, sm_idx)); + + bool did_batch_change = true; + constexpr bool IsEpiLoad = true; + + if (work_tile_info.is_valid()) { + collective_epilogue.tensormaps_perform_update( + shared_storage.tensormaps.epilogue, + params.epilogue, + epi_load_tensormap, + problem_shape_MNKL, + work_tile_info.L_idx, + 0 + ); + + // Converge before issuing tensormap fence release since fence is aligned + __syncwarp(); + collective_epilogue.tensormaps_cp_fence_release(shared_storage.tensormaps.epilogue, epi_load_tensormap, lane_predicate, 0); + } + + load_order_barrier.wait(); + + while (work_tile_info.is_valid()) { + int32_t curr_batch = work_tile_info.L_idx; + + // Get next work tile + auto next_work_tile_info = scheduler.fetch_next_work(work_tile_info); + + if (TileScheduler::compute_epilogue(work_tile_info, params.scheduler)) { + if constexpr (IsGroupedGemmKernel) { + problem_shape_MNKL = append<4>(params.problem_shape.get_problem_shape(work_tile_info.L_idx), work_tile_info.L_idx); + } + + // Compute m_coord, n_coord, l_coord with the post-tiled m-shape and n-shape + auto m_coord = idx2crd(work_tile_info.M_idx, shape<2>(gA_mkl)); + auto n_coord = idx2crd(work_tile_info.N_idx, shape<2>(gB_nkl)); + auto l_coord = idx2crd(work_tile_info.L_idx, shape<4>(gB_nkl)); + auto blk_coord = make_coord(m_coord, n_coord, _, l_coord); + + if (did_batch_change) { + collective_epilogue.tensormaps_fence_acquire(epi_load_tensormap); + } + + bool wait = work_tile_info.is_valid() && curr_batch != next_work_tile_info.L_idx; + + epi_load_pipe_producer_state = collective_epilogue.load( + epi_load_pipeline, + epi_load_pipe_producer_state, + problem_shape_MNKL, + blk_shape, + blk_coord, + tiled_mma, + lane_idx, + shared_storage.tensors.epilogue, + epi_load_tensormap, + work_tile_info.reduction_subtile_idx(), + wait + ); + } + + work_tile_info = next_work_tile_info; + did_batch_change = curr_batch != work_tile_info.L_idx; + + if (work_tile_info.is_valid() && did_batch_change) { + if constexpr (IsGroupedGemmKernel) { + problem_shape_MNKL = append<4>(params.problem_shape.get_problem_shape(work_tile_info.L_idx), work_tile_info.L_idx); + } + + // tensormap update + { + collective_epilogue.tensormaps_perform_update( + shared_storage.tensormaps.epilogue, + params.epilogue, + epi_load_tensormap, + problem_shape_MNKL, + work_tile_info.L_idx, + 0 + ); + + // Converge before issuing tensormap fence release since fence is aligned + __syncwarp(); + collective_epilogue.tensormaps_cp_fence_release(shared_storage.tensormaps.epilogue, epi_load_tensormap, lane_predicate, 0); + } + } + + } // Scheduler work fetch loop + + // Make sure all Consumer Warp Groups have been waited upon + 
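      // (load_tail blocks this producer warp until every epilogue-load stage it committed has been
      // released by the consumer warp groups, so the warp does not retire while C loads are in flight.)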
collective_epilogue.load_tail(epi_load_pipeline, epi_load_pipe_producer_state); + } // Epilogue Producer Warp End + } // Producer Warp Group End + + else if (warp_group_role == WarpGroupRole::Consumer0 || warp_group_role == WarpGroupRole::Consumer1) { + cutlass::arch::warpgroup_reg_alloc(); + + // Index of warp group within consumer warp groups + int consumer_warp_group_idx = warp_group_role == WarpGroupRole::Consumer0 ? 0 : 1; + + int32_t const sm_idx = blockIdx.x + (blockIdx.y * gridDim.x); + int32_t const sm_count = params.hw_info.sm_count; + // Do we potentially issue tail arrives for TMA stores, if epilogue load is waiting for it + bool do_store_tail = false; + // Get a copy of tensormaps + auto epi_store_tensormap = get<0>(collective_epilogue.store_init(params.epilogue, shared_storage.tensormaps.epilogue, sm_count, sm_idx, consumer_warp_group_idx)); + + bool did_batch_change = true; + constexpr bool IsEpiLoad = false; + + if (work_tile_info.is_valid()) { + + if (warp_idx_in_warp_group == 0) { + + collective_epilogue.tensormaps_perform_update( + shared_storage.tensormaps.epilogue, + params.epilogue, + epi_store_tensormap, + problem_shape_MNKL, + work_tile_info.L_idx, + consumer_warp_group_idx + ); + + // Converge before issuing tensormap fence release since fence is aligned + __syncwarp(); + collective_epilogue.tensormaps_cp_fence_release(shared_storage.tensormaps.epilogue, + epi_store_tensormap, + lane_predicate, + consumer_warp_group_idx); + } + } + + while (work_tile_info.is_valid()) { + if constexpr (IsGroupedGemmKernel) { + problem_shape_MNKL = append<4>(params.problem_shape.get_problem_shape(work_tile_info.L_idx), work_tile_info.L_idx); + } + + int32_t curr_batch = work_tile_info.L_idx; + + // Compute m_coord, n_coord, l_coord with the post-tiled m-shape and n-shape + auto m_coord = idx2crd(work_tile_info.M_idx, shape<2>(gA_mkl)); + auto n_coord = idx2crd(work_tile_info.N_idx, shape<2>(gB_nkl)); + auto l_coord = idx2crd(work_tile_info.L_idx, shape<4>(gB_nkl)); + auto blk_coord = make_coord(m_coord, n_coord, _, l_coord); + auto work_k_tile_count = TileScheduler::get_work_k_tile_count(work_tile_info, problem_shape_MNKL, blk_shape); + + // Allocate the accumulators for the (M,N) blk_shape + // + // MSVC CTAD breaks if we say "Tensor" here, so we use "auto" instead. 
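        // The ordered barrier used below is what produces the ping-pong schedule: each math warp
        // group waits for its turn before issuing MMAs, arrives to hand the mainloop slot to its
        // peer, waits again for its epilogue slot, and arrives once its stores are issued, so one
        // group's epilogue overlaps the other group's mainloop.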
+ auto accumulators = partition_fragment_C(tiled_mma, take<0,2>(blk_shape)); // (MMA,MMA_M,MMA_N) + + static_assert(cute::is_any_of_v, + detail::PersistentTileSchedulerSm90>); + if (TileScheduler::valid_warpgroup_in_work_tile(work_tile_info)) { + + math_wg_order_barrier.wait(); + + collective_mainloop.mma( + mainloop_pipeline, + mainloop_pipe_consumer_state, + accumulators, + work_k_tile_count, + mma_thread_idx, + shared_storage.tensors.mainloop, + params.mainloop + ); + + math_wg_order_barrier.arrive(); + + // Make sure the math instructions are done and free buffers before entering the epilogue + collective_mainloop.mma_tail( + mainloop_pipeline, + mainloop_pipe_consumer_state, + work_k_tile_count + ); + + math_wg_order_barrier.wait(); + + // Update starting mainloop pipeline state for the next tile + mainloop_pipe_consumer_state.advance(work_k_tile_count); + } + + // Perform reduction across splits, if needed + TileScheduler::fixup( + params.scheduler, work_tile_info, accumulators, NumMmaWarpGroups, consumer_warp_group_idx); + + if (did_batch_change) { + collective_epilogue.tensormaps_fence_acquire(epi_store_tensormap); + } + + if (TileScheduler::compute_epilogue(work_tile_info, params.scheduler)) { + + // Epilogue and write to gD + auto [epi_load_pipe_consumer_state_next, epi_store_pipe_producer_state_next] = + collective_epilogue.store( + epi_load_pipeline, + epi_load_pipe_consumer_state, + epi_store_pipeline, + epi_store_pipe_producer_state, + problem_shape_MNKL, + blk_shape, + blk_coord, + accumulators, + tiled_mma, + mma_thread_idx, + shared_storage.tensors.epilogue, + epi_store_tensormap, + work_tile_info.reduction_subtile_idx() + ); + + epi_load_pipe_consumer_state = epi_load_pipe_consumer_state_next; + epi_store_pipe_producer_state = epi_store_pipe_producer_state_next; + do_store_tail = true; + } + + // Get next work tile + auto next_work_tile_info = scheduler.fetch_next_work(work_tile_info); + work_tile_info = next_work_tile_info; + + // Skip a tile for pingpong + if (work_tile_info.is_valid()) { + if constexpr (IsGroupedGemmKernel) { + problem_shape_MNKL = append<4>(params.problem_shape.get_problem_shape(work_tile_info.L_idx), work_tile_info.L_idx); + } + work_k_tile_count = TileScheduler::get_work_k_tile_count(work_tile_info, problem_shape_MNKL, blk_shape); + mainloop_pipe_consumer_state.advance(work_k_tile_count); + + // Go to next tile + auto next_next_work_tile_info = scheduler.fetch_next_work(work_tile_info); + + work_tile_info = next_next_work_tile_info; + } + + did_batch_change = curr_batch != work_tile_info.L_idx; + if (work_tile_info.is_valid() && did_batch_change) { + if constexpr (IsGroupedGemmKernel) { + problem_shape_MNKL = append<4>(params.problem_shape.get_problem_shape(work_tile_info.L_idx), work_tile_info.L_idx); + } + if (warp_idx_in_warp_group == 0) { + collective_epilogue.tensormaps_perform_update( + shared_storage.tensormaps.epilogue, + params.epilogue, + epi_store_tensormap, + problem_shape_MNKL, + work_tile_info.L_idx, + consumer_warp_group_idx + ); + + // Converge before issuing tensormap fence release since fence is aligned + __syncwarp(); + collective_epilogue.tensormaps_cp_fence_release(shared_storage.tensormaps.epilogue, + epi_store_tensormap, + lane_predicate, + consumer_warp_group_idx); + } + } + + // TMA store pipeline wait is only visible to TMA-issuing warp, so for multiple-consumer kernels + // we need to wait for all TMA stores to complete before issuing consumer order barrier arrives + // to ensure next math consumer doesn't overwrite smem 
of in-flight TMA stores of current consumer. + auto [epi_load_pipe_consumer_state_next_, epi_store_pipe_producer_state_next_] = + collective_epilogue.store_tail( + epi_load_pipeline, + epi_load_pipe_consumer_state, + epi_store_pipeline, + epi_store_pipe_producer_state + ); + + // Update starting load/store pipeline states for the next tile + // state has already been incremented by 1 tile in collective calls, advance once again for ping pong + epi_load_pipe_consumer_state = epi_load_pipe_consumer_state_next_; + epi_store_pipe_producer_state = epi_store_pipe_producer_state_next_; + epi_load_pipe_consumer_state.advance(c_tile_count); + epi_store_pipe_producer_state.advance(d_tile_count); + + // Cue for next Math WG's Epilogue to start + math_wg_order_barrier.arrive(); + + } // Scheduler work fetch loop + } // Consumer Warp Groups End +#endif + } +}; + +/////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass::gemm::kernel \ No newline at end of file diff --git a/test/unit/gemm/device/CMakeLists.txt b/test/unit/gemm/device/CMakeLists.txt index a70ce542d0..348b185c74 100644 --- a/test/unit/gemm/device/CMakeLists.txt +++ b/test/unit/gemm/device/CMakeLists.txt @@ -338,12 +338,14 @@ cutlass_test_unit_add_executable( cutlass_test_unit_add_executable( cutlass_test_unit_gemm_device_tensorop_sm90_ptr_array sm90_gemm_f16_f16_f16_tensor_op_f32_ptr_array.cu + sm90_gemm_f16_f16_f16_tensor_op_f32_ptr_array_pingpong.cu ) # Group Gemm test cutlass_test_unit_add_executable( cutlass_test_unit_gemm_device_tensorop_sm90_group_gemm sm90_gemm_f16_f16_f16_tensor_op_f32_group_gemm.cu + sm90_gemm_f16_f16_f16_tensor_op_f32_group_gemm_pingpong.cu ) # Fused epilogue tests diff --git a/test/unit/gemm/device/gemm_testbed_3x_ptr_array.hpp b/test/unit/gemm/device/gemm_testbed_3x_ptr_array.hpp index e2d3f2d06a..085d3e7404 100644 --- a/test/unit/gemm/device/gemm_testbed_3x_ptr_array.hpp +++ b/test/unit/gemm/device/gemm_testbed_3x_ptr_array.hpp @@ -1005,7 +1005,7 @@ struct HostCollectiveEpilogue { stride_Aux = cutlass::make_cute_packed_stride(cutlass::gemm::TagToStrideC_t{}, cute::make_shape(M, N, 1)); } - static_assert(!IsGroupGemm or (IsGroupGemm and IsAuxOutEnabled)); + static_assert(!IsGroupGemm or (IsGroupGemm and !IsAuxOutEnabled)); if constexpr (IsAuxOutEnabled) { for (int32_t i = 0; i < L; ++i) { @@ -1323,8 +1323,16 @@ struct HostCollectiveEpilogue { cute::make_layout(cute::make_shape(M, N, 1), stride_d_host[batch])); auto Bias = cute::make_tensor(detail::make_iterator(IsDeBiasEnabled ? reference_dbias.host_data() : bias.host_data()), cute::make_layout(cute::make_shape(M, cute::_1{}))); - auto Aux = cute::make_tensor(detail::make_iterator(IsAuxInEnabled ? 
tensors_Aux[batch].host_data() : references_Aux[batch].host_data()), - cute::make_layout(cute::make_shape(M, N, 1), stride_Aux)); + auto Aux_layout = cute::make_layout(cute::make_shape(M, N, 1), stride_Aux); + auto Aux = [&]() { + auto ptr = recast_ptr(nullptr); + if (IsAuxInEnabled) { + ptr = detail::make_iterator(tensors_Aux[batch].host_data()); + } else if (IsAuxOutEnabled) { + ptr = detail::make_iterator(references_Aux[batch].host_data()); + } + return cute::make_tensor(ptr, Aux_layout); + }(); auto Valpha = cute::make_tensor(detail::make_iterator(alpha.host_data()), cute::make_layout(cute::make_shape(M, cute::_1{}))); auto Vbeta = cute::make_tensor(detail::make_iterator(beta.host_data()), diff --git a/test/unit/gemm/device/sm90_gemm_f16_f16_f16_tensor_op_f32_group_gemm.cu b/test/unit/gemm/device/sm90_gemm_f16_f16_f16_tensor_op_f32_group_gemm.cu index b93d936865..2a6d2339ec 100644 --- a/test/unit/gemm/device/sm90_gemm_f16_f16_f16_tensor_op_f32_group_gemm.cu +++ b/test/unit/gemm/device/sm90_gemm_f16_f16_f16_tensor_op_f32_group_gemm.cu @@ -78,7 +78,69 @@ constexpr int AlignmentC = 128 / cutlass::sizeof_bits::value; // M using ElementAccumulator = float; // Element type for internal accumulation using ArchTag = cutlass::arch::Sm90; // Tag indicating the minimum SM that supports the intended feature using OperatorClass = cutlass::arch::OpClassTensorOp; // Operator class tag -using TileShape = Shape<_256,_128,_64>; // Threadblock-level tile size +using TileShape = Shape<_128,_128,_64>; // Threadblock-level tile size +using ClusterShape = Shape<_2,_2,_1>; // Shape of the threadblocks in a cluster +using StageCountType = cutlass::gemm::collective::StageCountAuto; // Stage count maximized based on the tile size +using KernelSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecializedCooperative; // Kernel to launch +using EpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecializedCooperative; // Epilogue to launch + +using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + TileShape, ClusterShape, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAccumulator, ElementAccumulator, + ElementC, LayoutC *, AlignmentC, + ElementC, LayoutC *, AlignmentC, + EpilogueSchedule + >::CollectiveOp; + +using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + ArchTag, OperatorClass, + ElementA, LayoutA *, AlignmentA, + ElementB, LayoutB *, AlignmentB, + ElementAccumulator, + TileShape, ClusterShape, + cutlass::gemm::collective::StageCountAutoCarveout< + static_cast(sizeof(typename CollectiveEpilogue::SharedStorage))>, + KernelSchedule + >::CollectiveOp; + +using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + cutlass::gemm::GroupProblemShape>, + CollectiveMainloop, + CollectiveEpilogue +>; + + using namespace test::gemm::device; + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + bool result = TestAll(1.0, 1.0); + EXPECT_TRUE(result); + result = TestAll(1.0, 0.0); + EXPECT_TRUE(result); +} + +TEST(SM90_Device_Gemm_f16t_f16t_f32n_tensor_op_gmma_f32_group_gemm, 128x128x64_2x2x1_direct_store) { + +// A matrix configuration +using ElementA = cutlass::half_t; // Element type for A matrix operand +using LayoutA = cutlass::layout::RowMajor; // Layout type for A matrix operand +constexpr int AlignmentA = 128 / cutlass::sizeof_bits::value; // Memory access granularity/alignment of A matrix in units of elements (up to 16 bytes) + +// B matrix configuration +using ElementB = 
cutlass::half_t; // Element type for B matrix operand +using LayoutB = cutlass::layout::ColumnMajor; // Layout type for B matrix operand +constexpr int AlignmentB = 128 / cutlass::sizeof_bits::value; // Memory access granularity/alignment of B matrix in units of elements (up to 16 bytes) + +// C/D matrix configuration +using ElementC = cutlass::half_t; // Element type for C and D matrix operands +using LayoutC = cutlass::layout::ColumnMajor; // Layout type for C and D matrix operands +constexpr int AlignmentC = 128 / cutlass::sizeof_bits::value; // Memory access granularity/alignment of C matrix in units of elements (up to 16 bytes) + +// Core kernel configurations +using ElementAccumulator = float; // Element type for internal accumulation +using ArchTag = cutlass::arch::Sm90; // Tag indicating the minimum SM that supports the intended feature +using OperatorClass = cutlass::arch::OpClassTensorOp; // Operator class tag +using TileShape = Shape<_128,_128,_64>; // Threadblock-level tile size using ClusterShape = Shape<_2,_2,_1>; // Shape of the threadblocks in a cluster using StageCountType = cutlass::gemm::collective::StageCountAuto; // Stage count maximized based on the tile size using KernelSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecializedCooperative; // Kernel to launch @@ -115,6 +177,8 @@ using GemmKernel = cutlass::gemm::kernel::GemmUniversal< using Gemm = cutlass::gemm::device::GemmUniversalAdapter; bool result = TestAll(1.0, 1.0); EXPECT_TRUE(result); + result = TestAll(1.0, 0.0); + EXPECT_TRUE(result); } #endif // defined(CUTLASS_ARCH_MMA_MODIFIABLE_TMA_SM90_SUPPORTED) diff --git a/test/unit/gemm/device/sm90_gemm_f16_f16_f16_tensor_op_f32_group_gemm_pingpong.cu b/test/unit/gemm/device/sm90_gemm_f16_f16_f16_tensor_op_f32_group_gemm_pingpong.cu new file mode 100644 index 0000000000..09be8a490d --- /dev/null +++ b/test/unit/gemm/device/sm90_gemm_f16_f16_f16_tensor_op_f32_group_gemm_pingpong.cu @@ -0,0 +1,184 @@ +/*************************************************************************************************** + * Copyright (c) 2024 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Tests for device-wide Ptr-Array Ping-pong scheduler GEMM interface +*/ + +#include + +#include "cutlass/cutlass.h" +#include "cute/tensor.hpp" +#include "cute/atom/mma_atom.hpp" + +#include "cutlass/numeric_types.h" + +#include "cutlass/gemm/device/gemm_universal_adapter.h" +#include "cutlass/gemm/kernel/gemm_universal.hpp" +#include "cutlass/gemm/kernel/tile_scheduler.hpp" +#include "cutlass/gemm/collective/collective_builder.hpp" +#include "cutlass/epilogue/collective/collective_builder.hpp" +#include "cutlass/epilogue/collective/sm70_epilogue_vectorized.hpp" +#include "cutlass/epilogue/collective/default_epilogue.hpp" +#include "cutlass/epilogue/thread/linear_combination.h" + +#include "../../common/cutlass_unit_test.h" + +#include "gemm_testbed_3x_ptr_array.hpp" + +#if defined(CUTLASS_ARCH_MMA_MODIFIABLE_TMA_SM90_SUPPORTED) + +using namespace cute; + +TEST(SM90_Device_Gemm_f16t_f16t_f32n_tensor_op_gmma_f32_group_gemm_pingpong, 128x128x64_2x2x1) { + +// A matrix configuration +using ElementA = cutlass::half_t; // Element type for A matrix operand +using LayoutA = cutlass::layout::RowMajor; // Layout type for A matrix operand +constexpr int AlignmentA = 128 / cutlass::sizeof_bits::value; // Memory access granularity/alignment of A matrix in units of elements (up to 16 bytes) + +// B matrix configuration +using ElementB = cutlass::half_t; // Element type for B matrix operand +using LayoutB = cutlass::layout::ColumnMajor; // Layout type for B matrix operand +constexpr int AlignmentB = 128 / cutlass::sizeof_bits::value; // Memory access granularity/alignment of B matrix in units of elements (up to 16 bytes) + +// C/D matrix configuration +using ElementC = cutlass::half_t; // Element type for C and D matrix operands +using LayoutC = cutlass::layout::ColumnMajor; // Layout type for C and D matrix operands +constexpr int AlignmentC = 128 / cutlass::sizeof_bits::value; // Memory access granularity/alignment of C matrix in units of elements (up to 16 bytes) + +// Core kernel configurations +using ElementAccumulator = float; // Element type for internal accumulation +using ArchTag = cutlass::arch::Sm90; // Tag indicating the minimum SM that supports the intended feature +using OperatorClass = cutlass::arch::OpClassTensorOp; // Operator class tag +using TileShape = Shape<_128,_128,_64>; // Threadblock-level tile size +using ClusterShape = Shape<_2,_2,_1>; // Shape of the threadblocks in a cluster +using StageCountType = cutlass::gemm::collective::StageCountAuto; // Stage count maximized based on the tile size +using KernelSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecializedPingpong; // Kernel to launch +using EpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecializedPingpong; // Epilogue to launch + +using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm90, 
cutlass::arch::OpClassTensorOp, + TileShape, ClusterShape, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAccumulator, ElementAccumulator, + ElementC, LayoutC *, AlignmentC, + ElementC, LayoutC *, AlignmentC, + EpilogueSchedule + >::CollectiveOp; + +using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + ArchTag, OperatorClass, + ElementA, LayoutA *, AlignmentA, + ElementB, LayoutB *, AlignmentB, + ElementAccumulator, + TileShape, ClusterShape, + cutlass::gemm::collective::StageCountAutoCarveout< + static_cast(sizeof(typename CollectiveEpilogue::SharedStorage))>, + KernelSchedule + >::CollectiveOp; + +using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + cutlass::gemm::GroupProblemShape>, + CollectiveMainloop, + CollectiveEpilogue +>; + + using namespace test::gemm::device; + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + bool result = TestAll(1.0, 1.0); + EXPECT_TRUE(result); + result = TestAll(1.0, 0.0); + EXPECT_TRUE(result); +} + +TEST(SM90_Device_Gemm_f16t_f16t_f32n_tensor_op_gmma_f32_group_gemm_pingpong, 128x128x64_2x2x1_direct_store) { + +// A matrix configuration +using ElementA = cutlass::half_t; // Element type for A matrix operand +using LayoutA = cutlass::layout::RowMajor; // Layout type for A matrix operand +constexpr int AlignmentA = 128 / cutlass::sizeof_bits::value; // Memory access granularity/alignment of A matrix in units of elements (up to 16 bytes) + +// B matrix configuration +using ElementB = cutlass::half_t; // Element type for B matrix operand +using LayoutB = cutlass::layout::ColumnMajor; // Layout type for B matrix operand +constexpr int AlignmentB = 128 / cutlass::sizeof_bits::value; // Memory access granularity/alignment of B matrix in units of elements (up to 16 bytes) + +// C/D matrix configuration +using ElementC = cutlass::half_t; // Element type for C and D matrix operands +using LayoutC = cutlass::layout::ColumnMajor; // Layout type for C and D matrix operands +constexpr int AlignmentC = 128 / cutlass::sizeof_bits::value; // Memory access granularity/alignment of C matrix in units of elements (up to 16 bytes) + +// Core kernel configurations +using ElementAccumulator = float; // Element type for internal accumulation +using ArchTag = cutlass::arch::Sm90; // Tag indicating the minimum SM that supports the intended feature +using OperatorClass = cutlass::arch::OpClassTensorOp; // Operator class tag +using TileShape = Shape<_128,_128,_64>; // Threadblock-level tile size +using ClusterShape = Shape<_2,_2,_1>; // Shape of the threadblocks in a cluster +using StageCountType = cutlass::gemm::collective::StageCountAuto; // Stage count maximized based on the tile size +using KernelSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecializedPingpong; // Kernel to launch +using EpilogueSchedule = cutlass::epilogue::PtrArrayNoSmemWarpSpecialized; // Epilogue to launch + +using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + TileShape, ClusterShape, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAccumulator, ElementAccumulator, + ElementC, LayoutC *, AlignmentC, + ElementC, LayoutC *, AlignmentC, + EpilogueSchedule + >::CollectiveOp; + +using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + ArchTag, OperatorClass, + ElementA, LayoutA *, AlignmentA, + ElementB, LayoutB *, AlignmentB, + ElementAccumulator, + TileShape, ClusterShape, + cutlass::gemm::collective::StageCountAutoCarveout< 
+ static_cast(sizeof(typename CollectiveEpilogue::SharedStorage))>, + KernelSchedule + >::CollectiveOp; + +using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + cutlass::gemm::GroupProblemShape>, + CollectiveMainloop, + CollectiveEpilogue +>; + + using namespace test::gemm::device; + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + bool result = TestAll(1.0, 1.0); + EXPECT_TRUE(result); + result = TestAll(1.0, 0.0); + EXPECT_TRUE(result); +} + +#endif // defined(CUTLASS_ARCH_MMA_MODIFIABLE_TMA_SM90_SUPPORTED) \ No newline at end of file diff --git a/test/unit/gemm/device/sm90_gemm_f16_f16_f16_tensor_op_f32_ptr_array.cu b/test/unit/gemm/device/sm90_gemm_f16_f16_f16_tensor_op_f32_ptr_array.cu index dc581acf7f..53748dc81c 100644 --- a/test/unit/gemm/device/sm90_gemm_f16_f16_f16_tensor_op_f32_ptr_array.cu +++ b/test/unit/gemm/device/sm90_gemm_f16_f16_f16_tensor_op_f32_ptr_array.cu @@ -115,9 +115,11 @@ using GemmKernel = cutlass::gemm::kernel::GemmUniversal< using Gemm = cutlass::gemm::device::GemmUniversalAdapter; bool result = TestAll(1.0, 1.0); EXPECT_TRUE(result); + result = TestAll(1.0, 0.0); + EXPECT_TRUE(result); } -TEST(SM90_Device_Gemm_f16t_f16t_f32n_tensor_op_gmma_f32_ptr_array, 128x128x64_2x2x1_NoSmemEpi) { +TEST(SM90_Device_Gemm_f16t_f16t_f32n_tensor_op_gmma_f32_ptr_array, 128x128x64_2x2x1_direct_store) { // A matrix configuration using ElementA = cutlass::half_t; // Element type for A matrix operand @@ -173,6 +175,7 @@ using GemmKernel = cutlass::gemm::kernel::GemmUniversal< using namespace test::gemm::device; using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + EXPECT_TRUE(TestAll(1.0, 1.0)); EXPECT_TRUE(TestAll(1.0, 0.0)); } diff --git a/test/unit/gemm/device/sm90_gemm_f16_f16_f16_tensor_op_f32_ptr_array_pingpong.cu b/test/unit/gemm/device/sm90_gemm_f16_f16_f16_tensor_op_f32_ptr_array_pingpong.cu new file mode 100644 index 0000000000..5b91825c1d --- /dev/null +++ b/test/unit/gemm/device/sm90_gemm_f16_f16_f16_tensor_op_f32_ptr_array_pingpong.cu @@ -0,0 +1,182 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Tests for device-wide Ptr-Array GEMM interface +*/ + +#include + +#include "cutlass/cutlass.h" +#include "cute/tensor.hpp" +#include "cute/atom/mma_atom.hpp" + +#include "cutlass/numeric_types.h" + +#include "cutlass/gemm/device/gemm_universal_adapter.h" +#include "cutlass/gemm/kernel/gemm_universal.hpp" +#include "cutlass/gemm/kernel/tile_scheduler.hpp" +#include "cutlass/gemm/collective/collective_builder.hpp" +#include "cutlass/epilogue/collective/collective_builder.hpp" +#include "cutlass/epilogue/collective/sm70_epilogue_vectorized.hpp" +#include "cutlass/epilogue/collective/default_epilogue.hpp" +#include "cutlass/epilogue/thread/linear_combination.h" + +#include "../../common/cutlass_unit_test.h" + +#include "gemm_testbed_3x_ptr_array.hpp" + +#if defined(CUTLASS_ARCH_MMA_MODIFIABLE_TMA_SM90_SUPPORTED) + +using namespace cute; + +TEST(SM90_Device_Gemm_f16t_f16t_f32n_tensor_op_gmma_f32_ptr_array_pingpong, 128x128x64_2x2x1) { + +// A matrix configuration +using ElementA = cutlass::half_t; // Element type for A matrix operand +using LayoutA = cutlass::layout::RowMajor; // Layout type for A matrix operand +constexpr int AlignmentA = 128 / cutlass::sizeof_bits::value; // Memory access granularity/alignment of A matrix in units of elements (up to 16 bytes) + +// B matrix configuration +using ElementB = cutlass::half_t; // Element type for B matrix operand +using LayoutB = cutlass::layout::ColumnMajor; // Layout type for B matrix operand +constexpr int AlignmentB = 128 / cutlass::sizeof_bits::value; // Memory access granularity/alignment of B matrix in units of elements (up to 16 bytes) + +// C/D matrix configuration +using ElementC = cutlass::half_t; // Element type for C and D matrix operands +using LayoutC = cutlass::layout::ColumnMajor; // Layout type for C and D matrix operands +constexpr int AlignmentC = 128 / cutlass::sizeof_bits::value; // Memory access granularity/alignment of C matrix in units of elements (up to 16 bytes) + +// Core kernel configurations +using ElementAccumulator = float; // Element type for internal accumulation +using ArchTag = cutlass::arch::Sm90; // Tag indicating the minimum SM that supports the intended feature +using OperatorClass = cutlass::arch::OpClassTensorOp; // Operator class tag +using TileShape = Shape<_128,_128,_64>; // Threadblock-level tile size +using ClusterShape = Shape<_2,_2,_1>; // Shape of the threadblocks in a cluster +using StageCountType = cutlass::gemm::collective::StageCountAuto; // Stage count maximized based on the tile size +using KernelSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecializedPingpong; // Kernel to launch +using EpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecializedPingpong; // Epilogue to launch + +using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + TileShape, 
ClusterShape, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAccumulator, ElementAccumulator, + ElementC, LayoutC, AlignmentC, + ElementC, LayoutC, AlignmentC, + EpilogueSchedule + >::CollectiveOp; + +using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + ArchTag, OperatorClass, + ElementA, LayoutA, AlignmentA, + ElementB, LayoutB, AlignmentB, + ElementAccumulator, + TileShape, ClusterShape, + cutlass::gemm::collective::StageCountAutoCarveout< + static_cast(sizeof(typename CollectiveEpilogue::SharedStorage))>, + KernelSchedule + >::CollectiveOp; + +using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + cutlass::gemm::ArrayProblemShape>, + CollectiveMainloop, + CollectiveEpilogue +>; + + using namespace test::gemm::device; + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + bool result = TestAll(1.0, 1.0); + EXPECT_TRUE(result); + result = TestAll(1.0, 0.0); + EXPECT_TRUE(result); +} + +TEST(SM90_Device_Gemm_f16t_f16t_f32n_tensor_op_gmma_f32_ptr_array_pingpong, 128x128x64_2x2x1_direct_store) { + +// A matrix configuration +using ElementA = cutlass::half_t; // Element type for A matrix operand +using LayoutA = cutlass::layout::RowMajor; // Layout type for A matrix operand +constexpr int AlignmentA = 128 / cutlass::sizeof_bits::value; // Memory access granularity/alignment of A matrix in units of elements (up to 16 bytes) + +// B matrix configuration +using ElementB = cutlass::half_t; // Element type for B matrix operand +using LayoutB = cutlass::layout::ColumnMajor; // Layout type for B matrix operand +constexpr int AlignmentB = 128 / cutlass::sizeof_bits::value; // Memory access granularity/alignment of B matrix in units of elements (up to 16 bytes) + +// C/D matrix configuration +using ElementC = cutlass::half_t; // Element type for C and D matrix operands +using LayoutC = cutlass::layout::ColumnMajor; // Layout type for C and D matrix operands +constexpr int AlignmentC = 128 / cutlass::sizeof_bits::value; // Memory access granularity/alignment of C matrix in units of elements (up to 16 bytes) + +// Core kernel configurations +using ElementAccumulator = float; // Element type for internal accumulation +using ArchTag = cutlass::arch::Sm90; // Tag indicating the minimum SM that supports the intended feature +using OperatorClass = cutlass::arch::OpClassTensorOp; // Operator class tag +using TileShape = Shape<_128,_128,_64>; // Threadblock-level tile size +using ClusterShape = Shape<_2,_2,_1>; // Shape of the threadblocks in a cluster +using StageCountType = cutlass::gemm::collective::StageCountAuto; // Stage count maximized based on the tile size +using KernelSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecializedPingpong; // Kernel to launch +using EpilogueSchedule = cutlass::epilogue::PtrArrayNoSmemWarpSpecialized; // Epilogue to launch + +using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + TileShape, ClusterShape, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAccumulator, ElementAccumulator, + ElementC, LayoutC, AlignmentC, + ElementC, LayoutC, AlignmentC, + EpilogueSchedule + >::CollectiveOp; + +using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + ArchTag, OperatorClass, + ElementA, LayoutA, AlignmentA, + ElementB, LayoutB, AlignmentB, + ElementAccumulator, + TileShape, ClusterShape, + cutlass::gemm::collective::StageCountAutoCarveout< + static_cast(sizeof(typename 
CollectiveEpilogue::SharedStorage))>, + KernelSchedule + >::CollectiveOp; + +using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + cutlass::gemm::ArrayProblemShape>, + CollectiveMainloop, + CollectiveEpilogue +>; + + using namespace test::gemm::device; + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + EXPECT_TRUE(TestAll(1.0, 1.0)); + EXPECT_TRUE(TestAll(1.0, 0.0)); +} + +#endif // defined(CUTLASS_ARCH_MMA_MODIFIABLE_TMA_SM90_SUPPORTED) \ No newline at end of file From 3a8c01a18b24c35b216922481ac762496720a99d Mon Sep 17 00:00:00 2001 From: John Shumway Date: Wed, 11 Sep 2024 10:33:56 -0700 Subject: [PATCH 20/53] Prefix a member template name with the template keyword. (#1796) Fixes llvm build error. --- .../epilogue/collective/sm90_epilogue_tma_warpspecialized.hpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/include/cutlass/epilogue/collective/sm90_epilogue_tma_warpspecialized.hpp b/include/cutlass/epilogue/collective/sm90_epilogue_tma_warpspecialized.hpp index 4b2157b6c7..6d173bfe23 100644 --- a/include/cutlass/epilogue/collective/sm90_epilogue_tma_warpspecialized.hpp +++ b/include/cutlass/epilogue/collective/sm90_epilogue_tma_warpspecialized.hpp @@ -640,7 +640,7 @@ class CollectiveEpilogue< tRS_rC, thread_idx ); - auto cst_callbacks = fusion_callbacks.get_consumer_store_callbacks(cst_args); + auto cst_callbacks = fusion_callbacks.template get_consumer_store_callbacks(cst_args); bool is_producer_load_needed = fusion_callbacks.is_producer_load_needed(); bool is_C_load_needed = is_source_supported && fusion_callbacks.is_C_load_needed(); From 9f68995de585a883e3ff6b1d0347ea02aff55451 Mon Sep 17 00:00:00 2001 From: reed Date: Mon, 16 Sep 2024 23:55:09 +0800 Subject: [PATCH 21/53] =?UTF-8?q?add=20publication:=20=E2=80=98EVT:=20Acce?= =?UTF-8?q?lerating=20Deep=20Learning=20Training=20with=20Epilogue=20Visit?= =?UTF-8?q?or=20Tree=E2=80=99=20(#1526)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Haicheng Wu <57973641+hwu36@users.noreply.github.com> --- PUBLICATIONS.md | 2 ++ 1 file changed, 2 insertions(+) diff --git a/PUBLICATIONS.md b/PUBLICATIONS.md index 65d1f08e07..04d4cd0a14 100644 --- a/PUBLICATIONS.md +++ b/PUBLICATIONS.md @@ -2,6 +2,8 @@ ## 2024 +- ["EVT: Accelerating Deep Learning Training with Epilogue Visitor Tree"](https://dl.acm.org/doi/10.1145/3620666.3651369). Zhaodong Chen, Andrew Kerr, Richard Cai, Jack Kosaian, Haicheng Wu, Yufei Ding, and Yuan Xie. _Proceedings of the 29th ACM International Conference on Architectural Support for Programming Languages and Operating Systems_, April 2024. + - ["Faster Neighborhood Attention: Reducing the O(n^2) Cost of Self Attention at the Threadblock Level"](https://arxiv.org/abs/2403.04690). Ali Hassani, Wen-Mei Hwu, Humphrey Shi. _arXiv_, March 2024.
## 2023 From 1ebda1ccef14df97da9ed098bd20e0c8520d6972 Mon Sep 17 00:00:00 2001 From: Chenggang Zhao Date: Tue, 17 Sep 2024 00:38:42 +0800 Subject: [PATCH 22/53] Fix MMA promotion interval assertions (#1641) --- .../collective/sm90_mma_tma_gmma_ss_warpspecialized_fp8.hpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/include/cutlass/gemm/collective/sm90_mma_tma_gmma_ss_warpspecialized_fp8.hpp b/include/cutlass/gemm/collective/sm90_mma_tma_gmma_ss_warpspecialized_fp8.hpp index c281d4f5f7..6c02979996 100644 --- a/include/cutlass/gemm/collective/sm90_mma_tma_gmma_ss_warpspecialized_fp8.hpp +++ b/include/cutlass/gemm/collective/sm90_mma_tma_gmma_ss_warpspecialized_fp8.hpp @@ -246,8 +246,8 @@ struct CollectiveMma< implementable = implementable && cutlass::detail::check_alignment(cute::make_shape(M,K,L), StrideA{}); constexpr int min_tma_aligned_elements_B = tma_alignment_bits / cutlass::sizeof_bits::value; implementable = implementable && cutlass::detail::check_alignment(cute::make_shape(N,K,L), StrideB{}); - /* MMA promotion interval should be a multiple of 4, since each mainloop iteration would issue 4 MMA instructions. */ - implementable = implementable && (args.mma_promotion_interval % 4 == 0); + /* MMA promotion interval should be a multiple of the number of MMA instructions issued by each mainloop iteration. */ + implementable = implementable && (args.mma_promotion_interval % (size<2>(TileShape{})() / TiledMma().template tile_size_mnk<2>()()) == 0); if (!implementable) { CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Problem Size doesn't meet the minimum alignment requirements for TMA.\n"); From 2991ce18d3121a10e622f5adcccd5859f83190cd Mon Sep 17 00:00:00 2001 From: reed Date: Wed, 18 Sep 2024 22:37:24 +0800 Subject: [PATCH 23/53] Add print_svg for mma (#1733) * add print_svg for mma * correct the code indentation --- include/cute/atom/mma_atom.hpp | 177 +++++++++++++++++++++++++++++++++ 1 file changed, 177 insertions(+) diff --git a/include/cute/atom/mma_atom.hpp b/include/cute/atom/mma_atom.hpp index 6dc826ef2f..2358dd568f 100644 --- a/include/cute/atom/mma_atom.hpp +++ b/include/cute/atom/mma_atom.hpp @@ -928,6 +928,183 @@ print_latex_mma(LayoutC const& C, ThrIDC const& TC, // (m,n) -> (tid,vid) and printf(latex_footer); } +// MNK MMA Layout to SVG -- 8-value color coded by thread +template +CUTE_HOST_DEVICE +void +print_svg_mma(LayoutC const& C, ThrIDC const& TC, // (m,n) -> (tid,vid) and tid -> thr_idx + LayoutA const& A, ThrIDA const& TA, // (m,k) -> (tid,vid) and tid -> thr_idx + LayoutB const& B, ThrIDB const& TB) // (n,k) -> (tid,vid) and tid -> thr_idx +{ + char const *color_map[8] = {"175,175,255", "175,255,175", "255,255,175", + "255,175,175", "210,210,255", "210,255,210", + "255,255,210", "255,210,210"}; + + const int cell_width = 20; + const int cell_height = 20; + + const int page_width = (size<1>(A) + size<0>(B) + 2) * cell_width; + const int page_height = (size<1>(B) + size<0>(A) + 2) * cell_height; + + // header + printf("\n", + page_width, page_height); + + // C + int c_base_x = (size<1>(A) + 2) * cell_width; + int c_base_y = (size<1>(B) + 2) * cell_height; + for (int m = 0; m < cute::size<0>(C); ++m) { + for (int n = 0; n < cute::size<1>(C); ++n) { + + int thrid = C(m, n) % size(TC); + int val_idx = C(m, n) / size(TC); + int thr_idx = TC(thrid); + + int x = n * cell_width + c_base_x; + int y = m * cell_height + c_base_y; + + int thr_x = x + cell_width / 2; + int thr_y = y + cell_height / 4; + int val_x = x + cell_width / 2; + int val_y = y + cell_height * 3 
/ 4; + + printf("\n", + x, y, cell_width, cell_height, color_map[thr_idx % 8]); + + printf("T%d\n", + thr_x, thr_y, thr_idx); + printf("V%d\n", + val_x, val_y, val_idx); + } + } + + // A + int a_base_x = cell_width; + int a_base_y = (size<1>(B) + 2) * cell_height; + for (int m = 0; m < size<0>(A); ++m) { + for (int k = 0; k < size<1>(A); ++k) { + int thrid = A(m, k) % size(TA); + int val_idx = A(m, k) / size(TA); + int thr_idx = TA(thrid); + + int x = k * cell_width + a_base_x; + int y = m * cell_height + a_base_y; + + int thr_x = x + cell_width / 2; + int thr_y = y + cell_height / 4; + int val_x = x + cell_width / 2; + int val_y = y + cell_height * 3 / 4; + + printf("\n", + x, y, cell_width, cell_height, color_map[thr_idx % 8]); + printf("T%d\n", + thr_x, thr_y, thr_idx); + printf("V%d\n", + val_x, val_y, val_idx); + } + } + + // B + int b_base_x = (size<1>(A) + 2) * cell_width; + int b_base_y = cell_height; + for (int n = 0; n < size<0>(B); ++n) { + for (int k = 0; k < size<1>(B); ++k) { + int thrid = B(n, k) % size(TB); + int val_idx = B(n, k) / size(TB); + int thr_idx = TB(thrid); + + int x = n * cell_width + b_base_x; + int y = k * cell_height + b_base_y; + + int thr_x = x + cell_width / 2; + int thr_y = y + cell_height / 4; + int val_x = x + cell_width / 2; + int val_y = y + cell_height * 3 / 4; + + printf("\n", + x, y, cell_width, cell_height, color_map[thr_idx % 8]); + printf("T%d\n", + thr_x, thr_y, thr_idx); + printf("V%d\n", + val_x, val_y, val_idx); + } + } + + // A labels + for (int m = 0; m < size<0>(A); ++m) { + int x = cell_width / 2; + int y = m * cell_height + cell_height / 2 + a_base_y; + printf("%d\n", + x, y, m); + } + for (int k = 0; k < size<1>(A); ++k) { + int x = cell_width + k * cell_width + cell_width / 2; + int y = -cell_height / 2 + a_base_y; + printf("%d\n", + x, y, k); + } + + // B labels + for (int n = 0; n < size<0>(B); ++n) { + int x = b_base_x + cell_width * n + cell_width / 2; + int y = cell_height / 2; + printf("%d\n", + x, y, n); + } + for (int k = 0; k < size<1>(B); ++k) { + int x = b_base_x - cell_width / 2; + int y = cell_height * (k + 1) + cell_height / 2; + printf("%d\n", + x, y, k); + } + + // footer + printf(""); +} + +template +CUTE_HOST_DEVICE +void +print_svg(MMA_Atom const &mma_atom) { + print_svg(make_tiled_mma(mma_atom)); +} + +template +CUTE_HOST_DEVICE +void +print_svg(TiledMMA const &mma) { + auto layout_and_thrid_C = mma.get_layoutC_MN(); + auto layoutC_MN = get<0>(layout_and_thrid_C); + auto thrID_C = get<1>(layout_and_thrid_C); + + auto layout_and_thrid_A = mma.get_layoutA_MK(); + auto layoutA_MK = get<0>(layout_and_thrid_A); + auto thrID_A = get<1>(layout_and_thrid_A); + + auto layout_and_thrid_B = mma.get_layoutB_NK(); + auto layoutB_NK = get<0>(layout_and_thrid_B); + auto thrID_B = get<1>(layout_and_thrid_B); + + print_svg_mma(layoutC_MN, thrID_C, layoutA_MK, thrID_A, layoutB_NK, thrID_B); +} + } // namespace cute //////////////////////////////////////////////////////////////////////////////////////////////////// From 44dae8b90ef232ea663727470dfbbe9daff6972d Mon Sep 17 00:00:00 2001 From: Wenlei Bao <142055114+wenlei-bao@users.noreply.github.com> Date: Thu, 19 Sep 2024 08:40:30 -0700 Subject: [PATCH 24/53] Adjust profiler space for SM89 (#1553) --- python/cutlass_library/generator.py | 51 ++++++++++++++++++----------- 1 file changed, 32 insertions(+), 19 deletions(-) diff --git a/python/cutlass_library/generator.py b/python/cutlass_library/generator.py index c736551432..9f327154a6 100644 --- a/python/cutlass_library/generator.py 
+++ b/python/cutlass_library/generator.py @@ -4881,7 +4881,8 @@ def GenerateSM89_TensorOp_16832_fp8(manifest, cuda_version): return layouts = [ - (LayoutType.RowMajor, LayoutType.ColumnMajor, LayoutType.ColumnMajor) + (LayoutType.RowMajor, LayoutType.ColumnMajor, LayoutType.ColumnMajor), + (LayoutType.RowMajor, LayoutType.ColumnMajor, LayoutType.RowMajor) ] math_instructions = [ @@ -4935,43 +4936,49 @@ def GenerateSM89_TensorOp_16832_fp8(manifest, cuda_version): for math_inst in math_instructions: tile_descriptions = [ + TileDescription([256, 128, 128], 3, [4, 2, 1], math_inst, min_cc, max_cc), TileDescription([256, 128, 64], 3, [4, 2, 1], math_inst, min_cc, max_cc), - TileDescription([128, 256, 64], 3, [2, 4, 1], math_inst, min_cc, max_cc), TileDescription([256, 128, 64], 6, [4, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 256, 128], 3, [2, 4, 1], math_inst, min_cc, max_cc), + TileDescription([128, 256, 64], 3, [2, 4, 1], math_inst, min_cc, max_cc), TileDescription([128, 256, 64], 6, [2, 4, 1], math_inst, min_cc, max_cc), + TileDescription([256, 64, 128], 4, [4, 1, 1], math_inst, min_cc, max_cc), TileDescription([256, 64, 64], 3, [4, 1, 1], math_inst, min_cc, max_cc), - TileDescription([ 64, 256, 64], 3, [1, 4, 1], math_inst, min_cc, max_cc), TileDescription([256, 64, 64], 4, [4, 1, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 256, 128], 4, [1, 4, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 256, 64], 3, [1, 4, 1], math_inst, min_cc, max_cc), TileDescription([ 64, 256, 64], 4, [1, 4, 1], math_inst, min_cc, max_cc), + TileDescription([256, 32, 128], 4, [4, 1, 1], math_inst, min_cc, max_cc), TileDescription([256, 32, 64], 4, [4, 1, 1], math_inst, min_cc, max_cc), + TileDescription([ 32, 256, 128], 4, [1, 4, 1], math_inst, min_cc, max_cc), TileDescription([ 32, 256, 64], 4, [1, 4, 1], math_inst, min_cc, max_cc), + TileDescription([128, 128, 128], 3, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 128, 128], 4, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 128, 128], 5, [2, 2, 1], math_inst, min_cc, max_cc), TileDescription([128, 128, 64], 3, [2, 2, 1], math_inst, min_cc, max_cc), TileDescription([128, 128, 64], 4, [2, 2, 1], math_inst, min_cc, max_cc), TileDescription([128, 128, 64], 5, [2, 2, 1], math_inst, min_cc, max_cc), TileDescription([128, 128, 64], 6, [2, 2, 1], math_inst, min_cc, max_cc), - TileDescription([128, 64, 64], 6, [2, 2, 1], math_inst, min_cc, max_cc), - TileDescription([ 64, 128, 64], 6, [2, 2, 1], math_inst, min_cc, max_cc), - TileDescription([128, 32, 64], 6, [4, 1, 1], math_inst, min_cc, max_cc), - TileDescription([ 32, 128, 64], 6, [1, 4, 1], math_inst, min_cc, max_cc), - TileDescription([ 64, 64, 64], 6, [2, 2, 1], math_inst, min_cc, max_cc), - TileDescription([ 64, 64, 64], 10, [2, 2, 1], math_inst, min_cc, max_cc), - TileDescription([256, 128, 128], 3, [4, 2, 1], math_inst, min_cc, max_cc), - TileDescription([128, 256, 128], 3, [2, 4, 1], math_inst, min_cc, max_cc), - TileDescription([256, 64, 128], 4, [4, 1, 1], math_inst, min_cc, max_cc), - TileDescription([ 64, 256, 128], 4, [1, 4, 1], math_inst, min_cc, max_cc), - TileDescription([256, 32, 128], 4, [4, 1, 1], math_inst, min_cc, max_cc), - TileDescription([ 32, 256, 128], 4, [1, 4, 1], math_inst, min_cc, max_cc), - TileDescription([128, 128, 128], 3, [2, 2, 1], math_inst, min_cc, max_cc), - TileDescription([128, 128, 128], 4, [2, 2, 1], math_inst, min_cc, max_cc), - TileDescription([128, 128, 128], 5, [2, 2, 1], math_inst, min_cc, max_cc), 
TileDescription([128, 64, 128], 3, [2, 2, 1], math_inst, min_cc, max_cc), - TileDescription([ 64, 128, 128], 3, [2, 2, 1], math_inst, min_cc, max_cc), TileDescription([128, 64, 128], 4, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 128, 128], 3, [2, 2, 1], math_inst, min_cc, max_cc), TileDescription([ 64, 128, 128], 4, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 64, 64], 3, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 64, 64], 4, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 64, 64], 5, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 64, 64], 6, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 128, 64], 3, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 128, 64], 4, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 128, 64], 5, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 128, 64], 6, [2, 2, 1], math_inst, min_cc, max_cc), TileDescription([128, 32, 128], 4, [4, 1, 1], math_inst, min_cc, max_cc), + TileDescription([128, 32, 64], 6, [4, 1, 1], math_inst, min_cc, max_cc), TileDescription([ 32, 128, 128], 4, [1, 4, 1], math_inst, min_cc, max_cc), + TileDescription([ 32, 128, 64], 6, [1, 4, 1], math_inst, min_cc, max_cc), TileDescription([ 64, 64, 128], 5, [2, 2, 1], math_inst, min_cc, max_cc), TileDescription([ 64, 64, 128], 6, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 64, 64], 6, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 64, 64], 10, [2, 2, 1], math_inst, min_cc, max_cc), ] data_types = [ @@ -4981,6 +4988,12 @@ def GenerateSM89_TensorOp_16832_fp8(manifest, cuda_version): DataType.f32, math_inst.element_accumulator ], + [ + math_inst.element_a, + math_inst.element_b, + DataType.bf16, + math_inst.element_accumulator + ], ] operations = [] From e2b078992702d2671d9d2ef5c394fe679338a075 Mon Sep 17 00:00:00 2001 From: Junkai-Wu Date: Wed, 25 Sep 2024 23:28:10 +0800 Subject: [PATCH 25/53] Add some can implement rules of hopper convolution. 
(#1835) --- ..._implicit_gemm_gmma_ss_warpspecialized.hpp | 55 +++++++++++++++++++ 1 file changed, 55 insertions(+) diff --git a/include/cutlass/conv/collective/sm90_implicit_gemm_gmma_ss_warpspecialized.hpp b/include/cutlass/conv/collective/sm90_implicit_gemm_gmma_ss_warpspecialized.hpp index 13bb7c515c..84fd37b47e 100644 --- a/include/cutlass/conv/collective/sm90_implicit_gemm_gmma_ss_warpspecialized.hpp +++ b/include/cutlass/conv/collective/sm90_implicit_gemm_gmma_ss_warpspecialized.hpp @@ -375,6 +375,61 @@ struct CollectiveConv< return false; } + if (is_im2col_A || is_im2col_B) { + // Check valid corner values for TMA_LOAD_IM2COL, signed int ranging from [-corner_limit, corner_limit - 1] + constexpr int32_t corner_limit = 1 << (16 / NumSpatialDimensions - 1); + auto lower_corner_whd = detail::compute_lower_corner_whd(problem_shape); + for (int i = 0; i < problem_shape.RankS; ++i) { + implementable = implementable && lower_corner_whd[i] >= -corner_limit && lower_corner_whd[i] <= (corner_limit - 1); + } + auto upper_corner_whd = detail::compute_upper_corner_whd(problem_shape); + for (int i = 0; i < problem_shape.RankS; ++i) { + implementable = implementable && upper_corner_whd[i] >= -corner_limit && upper_corner_whd[i] <= (corner_limit - 1); + } + + if (!implementable) { + CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Padding values don't meet requirements for TMA LOAD IM2COL.\n"); + return false; + } + } + + // Wgrad kernels don't support non-packed output strides, non-packed tensor A stride (linearized) + if constexpr (ConvOp == conv::Operator::kWgrad) { + + const auto & input_shape = problem_shape.shape_A; + const auto & input_stride = problem_shape.stride_A; + + implementable &= input_stride[ProblemShape::RankT - 1] == 1; + int input_shape_size = 1; + for (int i = ProblemShape::RankT - 2; i >= 0; --i) { + input_shape_size *= input_shape[i + 1]; + implementable &= input_stride[i] == input_shape_size; + } + + const auto & output_shape = problem_shape.shape_C; + const auto & output_stride = problem_shape.stride_C; + + implementable &= output_stride[ProblemShape::RankT - 1] == 1; + int output_shape_size = 1; + for (int i = ProblemShape::RankT - 2; i >= 0; --i) { + output_shape_size *= output_shape[i + 1]; + implementable &= output_stride[i] == output_shape_size; + } + + if (!implementable) { + CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Wgrad kernels don't support non-packed output strides.\n"); + return false; + } + } + + // Conv kernels only support cross correlation mode currently. + implementable &= problem_shape.mode == cutlass::conv::Mode::kCrossCorrelation; + + if (!implementable) { + CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Conv kernels only support cross correlation mode currently.\n"); + return false; + } + if (problem_shape.groups > 1) { CUTLASS_TRACE_HOST(" CAN IMPLEMENT: This kernel does not support conv groups > 1.\n"); return false; From b27c49e84a8d1a2a1ec3341b7ad91164ce65c70e Mon Sep 17 00:00:00 2001 From: Wilber Date: Tue, 8 Oct 2024 00:38:32 +0800 Subject: [PATCH 26/53] Fix cute doc (#1529) --- media/docs/cute/0x_gemm_tutorial.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/media/docs/cute/0x_gemm_tutorial.md b/media/docs/cute/0x_gemm_tutorial.md index 533d4b4be0..451e52e4ed 100644 --- a/media/docs/cute/0x_gemm_tutorial.md +++ b/media/docs/cute/0x_gemm_tutorial.md @@ -150,7 +150,7 @@ This `local_tile` is simply shorthand for 1. 
apply the tiler via [`zipped_divide`](./02_layout_algebra.md#zipped-tiled-flat-divides) ```cpp // ((BLK_M,BLK_K),(m,k)) -Tensor gA_mk = zipped_divide(gA, select<0,2>(cta_tiler)); +Tensor gA_mk = zipped_divide(mA, select<0,2>(cta_tiler)); ``` 2. apply the coord to the second mode, the "Rest" mode, to extract out the correct tiles for this CTA. ```cpp From 477a67731736315a8c87cb60780a95d6f2985eec Mon Sep 17 00:00:00 2001 From: Alexander Zinoviev <8257131+alexander-zinoviev@users.noreply.github.com> Date: Mon, 7 Oct 2024 09:39:11 -0700 Subject: [PATCH 27/53] Fix typos in test/unit/conv/cache_testbed_output.h (#1652) Co-authored-by: Alexander Zinoviev --- test/unit/conv/cache_testbed_output.h | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/test/unit/conv/cache_testbed_output.h b/test/unit/conv/cache_testbed_output.h index 4f3981e83b..eec63a7977 100644 --- a/test/unit/conv/cache_testbed_output.h +++ b/test/unit/conv/cache_testbed_output.h @@ -659,7 +659,7 @@ inline CachedTestKey CreateCachedConv2dTestKey( ElementCompute alpha, ElementCompute beta, cutlass::TensorView A, - cutlass::TensorView B, + cutlass::TensorView B, cutlass::TensorView C ) { @@ -711,7 +711,7 @@ inline CachedTestKey CreateCachedConv2dWithBroadcastTestKey( ElementCompute alpha, ElementCompute beta, cutlass::TensorView A, - cutlass::TensorView B, + cutlass::TensorView B, cutlass::TensorView C ) { @@ -763,7 +763,7 @@ inline CachedTestKey CreateCachedConv2dWithReductionTestKey( ElementCompute alpha, ElementCompute beta, cutlass::TensorView A, - cutlass::TensorView B, + cutlass::TensorView B, cutlass::TensorView C ) { @@ -814,7 +814,7 @@ inline CachedTestKey CreateCachedConv3dTestKey( ElementCompute alpha, ElementCompute beta, cutlass::TensorView A, - cutlass::TensorView B, + cutlass::TensorView B, cutlass::TensorView C ) { From 0837a2a00a02cf53ff2a36bb9630a39dbf772b5c Mon Sep 17 00:00:00 2001 From: Feng Shijie Date: Tue, 8 Oct 2024 00:39:59 +0800 Subject: [PATCH 28/53] Fix typo in comment (#1787) --- .../fusion/sm90_visitor_compute_tma_warpspecialized.hpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/include/cutlass/epilogue/fusion/sm90_visitor_compute_tma_warpspecialized.hpp b/include/cutlass/epilogue/fusion/sm90_visitor_compute_tma_warpspecialized.hpp index 8f5ceb5489..2ae10a688a 100644 --- a/include/cutlass/epilogue/fusion/sm90_visitor_compute_tma_warpspecialized.hpp +++ b/include/cutlass/epilogue/fusion/sm90_visitor_compute_tma_warpspecialized.hpp @@ -78,7 +78,7 @@ using namespace detail; // the template argument. 
// // template -// struct FooHomogeneous : public Foo {}; +// struct FooHomogeneous : public Foo {}; // template< template class ComputeFn, From cc3c29a81a140f7b97045718fb88eb0664c37bd7 Mon Sep 17 00:00:00 2001 From: Yujia Zhai Date: Wed, 9 Oct 2024 12:33:27 -0700 Subject: [PATCH 29/53] CUTLASS 3.6.0 (#1850) * v3.6 * update changelog * update readme * fix typo * fixing typos * hopper gemm with weight prefetch --------- Co-authored-by: yuzhai Co-authored-by: Haicheng Wu --- CHANGELOG.md | 20 + CMakeLists.txt | 66 +- PUBLICATIONS.md | 2 + README.md | 70 +- cmake/CTestTestfile.configure.cmake | 2 - cmake/CTestTestfile.test.configure.cmake | 8 +- cmake/googletest.cmake | 3 +- examples/35_gemm_softmax/gemm_softmax.cu | 8 +- .../48_hopper_warp_specialized_gemm.cu | 12 +- .../gather_gemm.hpp | 2 +- .../53_hopper_gemm_permute/permute_traits.hpp | 21 +- .../55_hopper_int4_fp8_gemm.cu | 701 + .../55_hopper_mixed_dtype_gemm.cu | 154 +- .../55_hopper_mixed_dtype_gemm/CMakeLists.txt | 13 +- examples/55_hopper_mixed_dtype_gemm/README.md | 2 + .../packed_scale.hpp | 132 + .../56_hopper_ptr_array_batched_gemm.cu | 3 +- .../57_hopper_grouped_gemm.cu | 6 +- .../CMakeLists.txt | 4 + .../61_hopper_gemm_with_topk_and_softmax.cu | 534 + .../CMakeLists.txt | 32 + .../62_hopper_sparse_gemm.cu | 596 + examples/62_hopper_sparse_gemm/CMakeLists.txt | 36 + .../63_hopper_gemm_with_weight_prefetch.cu | 500 + .../CMakeLists.txt | 36 + .../README.md | 82 + .../collective/builder.hpp | 215 + .../collective/dispatch_policy_extra.hpp | 61 + ..._gmma_ss_warpspecialized_with_prefetch.hpp | 867 + .../gemm_with_weight_prefetch_commandline.hpp | 117 + ...gemm_tma_warpspecialized_with_prefetch.hpp | 561 + .../pipeline/prefetch_pipeline_sm90.hpp | 161 + examples/CMakeLists.txt | 3 + examples/cute/tutorial/tiled_copy.cu | 4 +- include/cute/algorithm/clear.hpp | 6 +- include/cute/algorithm/cooperative_copy.hpp | 12 +- include/cute/algorithm/cooperative_gemm.hpp | 4 +- include/cute/algorithm/copy.hpp | 12 +- include/cute/algorithm/functional.hpp | 7 +- include/cute/algorithm/prefetch.hpp | 8 +- include/cute/algorithm/tuple_algorithms.hpp | 130 +- include/cute/arch/cluster_sm90.hpp | 4 +- include/cute/arch/config.hpp | 50 + include/cute/arch/copy_sm50.hpp | 30 +- include/cute/arch/copy_sm90.hpp | 15 +- include/cute/arch/copy_sm90_desc.hpp | 53 +- include/cute/arch/copy_sm90_tma.hpp | 95 +- include/cute/arch/mma.hpp | 6 +- include/cute/arch/mma_sm90.hpp | 3825 +- include/cute/arch/mma_sm90_desc.hpp | 7 +- include/cute/arch/mma_sm90_gmma.hpp | 3420 +- include/cute/arch/mma_sm90_gmma_sparse.hpp | 53789 ++++++++++++++++ include/cute/arch/util.hpp | 23 +- include/cute/atom/copy_atom.hpp | 95 +- include/cute/atom/copy_traits_sm50.hpp | 19 +- include/cute/atom/copy_traits_sm90_im2col.hpp | 12 +- include/cute/atom/copy_traits_sm90_tma.hpp | 29 +- .../atom/copy_traits_sm90_tma_swizzle.hpp | 22 + include/cute/atom/mma_atom.hpp | 215 +- include/cute/atom/mma_traits.hpp | 113 +- include/cute/atom/mma_traits_sm90.hpp | 18 +- include/cute/atom/mma_traits_sm90_gmma.hpp | 4320 +- .../cute/atom/mma_traits_sm90_gmma_sparse.hpp | 16915 +++++ include/cute/config.hpp | 13 - include/cute/container/alignment.hpp | 20 +- include/cute/container/array_aligned.hpp | 4 +- include/cute/container/array_subbyte.hpp | 14 + include/cute/container/bit_field.hpp | 4 +- include/cute/container/cuda_types.hpp | 8 +- include/cute/container/tuple.hpp | 13 +- include/cute/container/type_list.hpp | 3 +- include/cute/int_tuple.hpp | 166 +- include/cute/layout.hpp | 160 +- 
include/cute/layout_composed.hpp | 6 +- include/cute/numeric/arithmetic_tuple.hpp | 12 +- include/cute/numeric/complex.hpp | 6 +- include/cute/numeric/int.hpp | 14 +- include/cute/numeric/integral_constant.hpp | 25 +- include/cute/numeric/integral_ratio.hpp | 9 +- include/cute/numeric/math.hpp | 18 +- include/cute/numeric/numeric_types.hpp | 72 +- include/cute/numeric/real.hpp | 18 + include/cute/pointer.hpp | 29 +- include/cute/pointer_base.hpp | 7 +- include/cute/pointer_flagged.hpp | 63 +- include/cute/pointer_sparse.hpp | 172 + include/cute/pointer_swizzle.hpp | 20 +- include/cute/stride.hpp | 131 +- include/cute/swizzle.hpp | 21 +- include/cute/swizzle_layout.hpp | 42 +- include/cute/tensor.hpp | 3 + include/cute/tensor_impl.hpp | 82 +- include/cute/tensor_predicate.hpp | 5 +- include/cute/tensor_zip.hpp | 243 + include/cute/underscore.hpp | 9 +- include/cute/util/print.hpp | 34 +- include/cute/util/type_traits.hpp | 11 +- include/cutlass/arch/barrier.h | 47 + include/cutlass/arch/config.h | 81 + .../cutlass/arch/grid_dependency_control.h | 84 + include/cutlass/arch/memory_sm80.h | 9 + include/cutlass/arch/mma_sm90.h | 25 +- include/cutlass/arch/reg_reconfig.h | 6 +- include/cutlass/arch/synclog.hpp | 1324 + include/cutlass/array.h | 215 +- include/cutlass/bfloat16.h | 155 +- include/cutlass/cluster_launch.hpp | 1 + ..._implicit_gemm_gmma_ss_warpspecialized.hpp | 184 +- include/cutlass/conv/convnd_problem_shape.hpp | 78 +- include/cutlass/conv/detail.hpp | 137 + .../conv/device/conv_universal_adapter.hpp | 17 +- .../cutlass/conv/device/direct_convolution.h | 2 + .../conv/device/implicit_gemm_convolution.h | 19 +- .../device/implicit_gemm_convolution_fusion.h | 1 + include/cutlass/conv/dispatch_policy.hpp | 6 +- .../cutlass/conv/kernel/conv_universal.hpp | 2 + ...sm90_implicit_gemm_tma_warpspecialized.hpp | 369 +- include/cutlass/cuda_host_adapter.hpp | 2 + include/cutlass/cutlass.h | 1 + include/cutlass/detail/collective.hpp | 1 - include/cutlass/detail/layout.hpp | 21 +- include/cutlass/detail/mma.hpp | 5 + include/cutlass/device_kernel.h | 12 + .../collective/builders/sm90_builder.inl | 36 +- .../collective/collective_builder.hpp | 5 +- .../collective/collective_epilogue.hpp | 8 + .../cutlass/epilogue/collective/detail.hpp | 11 +- .../collective/sm70_epilogue_vectorized.hpp | 366 +- .../sm70_epilogue_vectorized_array.hpp | 412 + ...m90_epilogue_array_tma_warpspecialized.hpp | 86 +- .../sm90_epilogue_tma_warpspecialized.hpp | 59 +- ...e_tma_warpspecialized_bias_elementwise.hpp | 9 +- include/cutlass/epilogue/dispatch_policy.hpp | 3 +- .../cutlass/epilogue/fusion/operations.hpp | 66 +- .../sm90_callbacks_tma_warpspecialized.hpp | 343 +- ...90_visitor_compute_tma_warpspecialized.hpp | 31 +- .../sm90_visitor_load_tma_warpspecialized.hpp | 470 +- ...sm90_visitor_store_tma_warpspecialized.hpp | 428 +- .../sm90_visitor_tma_warpspecialized.hpp | 12 +- .../fusion/sm90_visitor_topk_softmax.hpp | 759 + include/cutlass/epilogue/thread/activation.h | 63 +- .../linear_combination_bias_elementwise.h | 119 +- .../threadblock/default_epilogue_tensor_op.h | 38 + include/cutlass/float8.h | 12 + include/cutlass/functional.h | 78 +- .../gemm/collective/builders/sm90_common.inl | 73 +- .../collective/builders/sm90_gmma_builder.inl | 67 +- .../builders/sm90_sparse_config.inl | 268 + .../builders/sm90_sparse_gmma_builder.inl | 388 + .../gemm/collective/collective_builder.hpp | 1 + .../gemm/collective/collective_mma.hpp | 1 + ..._mma_array_tma_gmma_ss_warpspecialized.hpp | 5 +- 
...mma_multistage_gmma_rs_warpspecialized.hpp | 2 +- ...mma_multistage_gmma_ss_warpspecialized.hpp | 2 +- .../sm90_mma_tma_gmma_rs_warpspecialized.hpp | 2 +- ...ma_gmma_rs_warpspecialized_mixed_input.hpp | 719 +- .../sm90_mma_tma_gmma_ss_warpspecialized.hpp | 2 +- ...90_mma_tma_gmma_ss_warpspecialized_fp8.hpp | 2 +- ...sparse_mma_tma_gmma_ss_warpspecialized.hpp | 724 + include/cutlass/gemm/device/base_grouped.h | 1 + .../gemm/device/default_gemm_configuration.h | 90 +- include/cutlass/gemm/device/ell_gemm.h | 1 + include/cutlass/gemm/device/gemm.h | 1 + include/cutlass/gemm/device/gemm_array.h | 1 + include/cutlass/gemm/device/gemm_batched.h | 1 + include/cutlass/gemm/device/gemm_complex.h | 1 + include/cutlass/gemm/device/gemm_sparse.h | 1 + .../gemm/device/gemm_sparse_with_absmax.h | 1 + .../gemm/device/gemm_splitk_parallel.h | 1 + .../gemm/device/gemm_universal_adapter.h | 60 +- .../cutlass/gemm/device/gemm_universal_base.h | 2 + include/cutlass/gemm/device/gemv.h | 1 + include/cutlass/gemm/device/rank_2k.h | 1 + include/cutlass/gemm/device/rank_k.h | 1 + include/cutlass/gemm/device/symm.h | 1 + include/cutlass/gemm/device/trmm.h | 1 + include/cutlass/gemm/dispatch_policy.hpp | 30 +- include/cutlass/gemm/kernel/sm70_gemm.hpp | 2 + ..._array_tma_warpspecialized_cooperative.hpp | 55 +- ...emm_array_tma_warpspecialized_pingpong.hpp | 31 +- include/cutlass/gemm/kernel/sm90_gemm_tma.hpp | 2 + .../kernel/sm90_gemm_tma_warpspecialized.hpp | 162 +- ...0_gemm_tma_warpspecialized_cooperative.hpp | 104 +- ...sm90_gemm_tma_warpspecialized_pingpong.hpp | 44 +- .../gemm/kernel/sm90_gemm_warpspecialized.hpp | 4 +- .../sm90_gemm_warpspecialized_cooperative.hpp | 17 +- .../sm90_gemm_warpspecialized_pingpong.hpp | 7 +- .../gemm/kernel/sm90_tile_scheduler.hpp | 18 +- .../gemm/kernel/sm90_tile_scheduler_group.hpp | 17 +- .../kernel/sm90_tile_scheduler_stream_k.hpp | 99 +- .../gemm/kernel/static_tile_scheduler.hpp | 112 +- .../cutlass/gemm/kernel/tile_scheduler.hpp | 40 +- .../gemm/kernel/tile_scheduler_params.h | 9 +- .../gemm/warp/default_mma_tensor_op_sm80.h | 35 +- .../gemm/warp/mma_mixed_input_tensor_op.h | 12 +- include/cutlass/gemm/warp/mma_tensor_op.h | 2 +- include/cutlass/kernel_launch.h | 68 + include/cutlass/numeric_conversion.h | 604 +- include/cutlass/pipeline/sm90_pipeline.hpp | 24 +- include/cutlass/platform/platform.h | 1 - .../cutlass/reduction/device/reduce_split_k.h | 27 +- .../device/tensor_reduce_affine_contiguous.h | 1 + .../device/tensor_reduce_affine_strided.h | 1 + include/cutlass/subbyte_reference.h | 6 +- include/cutlass/tensor_ref.h | 3 +- .../device/transform_universal_adapter.hpp | 167 +- .../kernel/filter_format_transformer.hpp | 28 +- .../kernel/sm90_sparse_gemm_compressor.hpp | 578 + .../kernel/sparse_gemm_compressor.hpp | 284 + include/cutlass/uint128.h | 7 +- include/cutlass/version.h | 4 +- media/docs/dependent_kernel_launch.md | 32 + media/docs/profiler.md | 89 +- media/docs/programming_guidelines.md | 7 +- media/docs/utilities.md | 48 + pyproject.toml | 2 +- python/cutlass/__init__.py | 2 +- python/cutlass/backend/epilogue.py | 6 +- .../cutlass/backend/evt/backend/sm90_nodes.py | 4 +- python/cutlass/emit/pytorch.py | 4 +- python/cutlass_library/conv3x_emitter.py | 3 + python/cutlass_library/gemm_operation.py | 20 +- python/cutlass_library/generator.py | 1683 +- python/cutlass_library/library.py | 2 + python/cutlass_library/manifest.py | 17 +- python/cutlass_library/sm90_shapes.py | 212 + python/cutlass_library/sm90_utils.py | 601 + python/setup_library.py | 2 +- 
python/setup_pycute.py | 2 +- test/self_contained_includes/CMakeLists.txt | 95 + test/unit/CMakeLists.txt | 2 +- test/unit/conv/cache_testbed_output.h | 2 +- test/unit/conv/device/conv2d_testbed.h | 14 +- test/unit/conv/device/conv3d_testbed.h | 8 +- .../conv/device_3x/conv_problem_sizes.hpp | 145 +- ..._implicit_gemm_f16_f16_f32_tensorop_f16.cu | 16 + ..._implicit_gemm_f16_f16_f32_tensorop_f32.cu | 16 + ..._implicit_gemm_f16_f16_f32_tensorop_f16.cu | 16 + ..._implicit_gemm_f16_f16_f32_tensorop_f32.cu | 16 + ..._implicit_gemm_f16_f16_f32_tensorop_f16.cu | 16 + ..._implicit_gemm_f16_f16_f32_tensorop_f32.cu | 16 + ..._implicit_gemm_f16_f16_f32_tensorop_f16.cu | 18 +- ..._implicit_gemm_f16_f16_f32_tensorop_f32.cu | 19 + ...op_implicit_gemm_s8_s8_s32_tensorop_s32.cu | 22 +- ...mplicit_gemm_tf32_tf32_f32_tensorop_f32.cu | 17 + ..._implicit_gemm_f16_f16_f32_tensorop_f16.cu | 17 + ..._implicit_gemm_f16_f16_f32_tensorop_f32.cu | 16 + ...op_implicit_gemm_s8_s8_s32_tensorop_s32.cu | 16 + ...mplicit_gemm_tf32_tf32_f32_tensorop_f32.cu | 16 + ..._implicit_gemm_f16_f16_f32_tensorop_f16.cu | 18 +- ..._implicit_gemm_f16_f16_f32_tensorop_f32.cu | 16 + ...op_implicit_gemm_s8_s8_s32_tensorop_s32.cu | 16 + ...mplicit_gemm_tf32_tf32_f32_tensorop_f32.cu | 16 + test/unit/conv/device_3x/testbed_conv.hpp | 15 +- ..._implicit_gemm_f16_f16_f32_tensorop_f16.cu | 16 + ..._implicit_gemm_f16_f16_f32_tensorop_f32.cu | 16 + ..._implicit_gemm_f16_f16_f32_tensorop_f16.cu | 17 + ..._implicit_gemm_f16_f16_f32_tensorop_f32.cu | 16 + ..._implicit_gemm_f16_f16_f32_tensorop_f16.cu | 16 + ..._implicit_gemm_f16_f16_f32_tensorop_f32.cu | 16 + test/unit/core/fast_numeric_conversion.cu | 3 + test/unit/core/functional.cu | 26 +- test/unit/core/numeric_conversion.cu | 126 + test/unit/cute/ampere/cooperative_copy.cu | 2 + test/unit/cute/ampere/cooperative_gemm.cu | 2 + test/unit/cute/ampere/tiled_cp_async.cu | 1 + test/unit/cute/core/CMakeLists.txt | 3 +- test/unit/cute/core/composition.cpp | 9 +- test/unit/cute/core/domain_distribute.cpp | 7 +- test/unit/cute/core/int_tuple.cpp | 178 +- test/unit/cute/core/inverse_left.cpp | 6 +- test/unit/cute/core/inverse_right.cpp | 5 +- test/unit/cute/core/math.cpp | 10 + test/unit/cute/core/swizzle_layout.cpp | 116 + test/unit/cute/hopper/cooperative_gemm.cu | 2 + test/unit/cute/layout/layout_operator.cu | 3 + test/unit/cute/volta/cooperative_gemm.cu | 2 + test/unit/gemm/device/CMakeLists.txt | 28 +- test/unit/gemm/device/gemm_testbed_3x.hpp | 805 +- test/unit/gemm/device/gemm_testbed_3x_evt.hpp | 1060 +- .../gemm/device/gemm_testbed_3x_ptr_array.hpp | 51 +- .../gemm_testbed_3x_tensor_broadcast.hpp | 8 +- ...8n_bf16t_mixed_input_tensor_op_f32_sm80.cu | 16 +- ...6n_bf16t_mixed_input_tensor_op_f32_sm80.cu | 16 +- test/unit/gemm/device/sm90_evt_operations.hpp | 248 +- ...er_warpspecialized_cooperative_aux_load.cu | 54 + ...r_warpspecialized_cooperative_aux_store.cu | 64 + ...ster_warpspecialized_cooperative_reduce.cu | 6 +- ...cluster_warpspecialized_pingpong_reduce.cu | 6 +- ...mm_f16_f16_f16_tensor_op_f32_group_gemm.cu | 126 + ...6_f16_tensor_op_f32_group_gemm_pingpong.cu | 65 +- ...16_f16_tensor_op_f32_ptr_array_pingpong.cu | 2 +- ..._rs_cluster_warpspecialized_cooperative.cu | 12 +- .../sm90_gemm_f32_f32_f32_tensor_op_f32.cu | 6 +- ..._rs_cluster_warpspecialized_cooperative.cu | 12 +- .../sm90_gemm_f8_f8_f8_tensor_op_fp32_evt.cu | 58 + .../device/sm90_gemm_stream_k_scheduler.cu | 2 +- ...0_sparse_gemm_f16_f16_f32_tensor_op_f32.cu | 255 + ...m90_sparse_gemm_f8_f8_f32_tensor_op_f32.cu | 216 + 
...m90_sparse_gemm_s8_s8_s32_tensor_op_s32.cu | 216 + ...sparse_gemm_tf32_tf32_f32_tensor_op_f32.cu | 216 + test/unit/gemm/device/testbed.h | 96 +- .../gemm/device/testbed_gemm_with_broadcast.h | 2 +- .../gemm/device/testbed_gemm_with_reduction.h | 2 +- test/unit/gemm/device/testbed_universal.h | 1 + test/unit/gemm/threadblock/mma_multistage.cu | 14 - test/unit/gemm/warp/gemm_sm80.cu | 209 - test/unit/transform/device/CMakeLists.txt | 58 + .../device/sm90_sparse_gemm_compressor_f16.cu | 95 + .../device/sm90_sparse_gemm_compressor_f32.cu | 95 + .../device/sm90_sparse_gemm_compressor_f8.cu | 95 + .../sm90_sparse_gemm_compressor_legacy.hpp | 480 + .../device/testbed_sparse_gemm_compressor.hpp | 876 + tools/library/CMakeLists.txt | 4 +- .../include/cutlass/library/arch_mappings.h | 6 + .../library/include/cutlass/library/library.h | 12 +- tools/library/src/conv_operation_3x.hpp | 4 +- tools/library/src/gemm_operation_3x.hpp | 12 +- .../src/reference/gemm_fp_mixed_input.cu | 4 +- tools/library/src/reference/gemm_fp_other.cu | 8 + .../src/reference/gemm_int_mixed_input.cu | 4 +- tools/library/src/reference/gemm_s8_s8_s32.cu | 146 + ...mm_int8_canonical.cu => gemm_u8_u8_s32.cu} | 86 +- .../initialize_reference_operations.cu | 7 +- .../library/src/sparse_gemm_operation_3x.hpp | 445 + tools/library/src/util.cu | 1 + .../include/cutlass/profiler/cublas_helpers.h | 122 +- .../cutlass/profiler/cutlass_profiler.h | 7 +- .../cutlass/profiler/device_allocation.h | 53 +- .../include/cutlass/profiler/device_context.h | 46 +- .../include/cutlass/profiler/options.h | 31 +- .../profiler/src/conv2d_operation_profiler.cu | 229 +- .../profiler/src/conv3d_operation_profiler.cu | 208 +- tools/profiler/src/cublas_helpers.cu | 275 +- tools/profiler/src/cutlass_profiler.cu | 13 - tools/profiler/src/device_allocation.cu | 408 +- tools/profiler/src/device_context.cu | 79 +- tools/profiler/src/gemm_operation_profiler.cu | 113 +- tools/profiler/src/operation_profiler.cu | 12 +- tools/profiler/src/options.cu | 180 +- tools/profiler/src/performance_report.cpp | 42 +- .../src/rank_2k_operation_profiler.cu | 114 +- .../profiler/src/rank_k_operation_profiler.cu | 113 +- .../src/sparse_gemm_operation_profiler.cu | 123 +- tools/profiler/src/symm_operation_profiler.cu | 127 +- tools/profiler/src/trmm_operation_profiler.cu | 128 +- .../util/include/cutlass/util/device_memory.h | 43 +- .../util/include/cutlass/util/distribution.h | 12 +- tools/util/include/cutlass/util/host_tensor.h | 11 +- .../include/cutlass/util/packed_stride.hpp | 2 + .../util/reference/device/tensor_fill.h | 127 +- .../cutlass/util/reference/host/conv.hpp | 3 +- .../cutlass/util/reference/host/gett.hpp | 6 +- .../cutlass/util/reference/host/tensor_fill.h | 243 +- 354 files changed, 105914 insertions(+), 8174 deletions(-) create mode 100644 examples/55_hopper_mixed_dtype_gemm/55_hopper_int4_fp8_gemm.cu create mode 100644 examples/55_hopper_mixed_dtype_gemm/packed_scale.hpp create mode 100644 examples/61_hopper_gemm_with_topk_and_softmax/61_hopper_gemm_with_topk_and_softmax.cu create mode 100644 examples/61_hopper_gemm_with_topk_and_softmax/CMakeLists.txt create mode 100644 examples/62_hopper_sparse_gemm/62_hopper_sparse_gemm.cu create mode 100644 examples/62_hopper_sparse_gemm/CMakeLists.txt create mode 100644 examples/63_hopper_gemm_with_weight_prefetch/63_hopper_gemm_with_weight_prefetch.cu create mode 100644 examples/63_hopper_gemm_with_weight_prefetch/CMakeLists.txt create mode 100644 examples/63_hopper_gemm_with_weight_prefetch/README.md create mode 
100644 examples/63_hopper_gemm_with_weight_prefetch/collective/builder.hpp create mode 100644 examples/63_hopper_gemm_with_weight_prefetch/collective/dispatch_policy_extra.hpp create mode 100644 examples/63_hopper_gemm_with_weight_prefetch/collective/sm90_mma_tma_gmma_ss_warpspecialized_with_prefetch.hpp create mode 100644 examples/63_hopper_gemm_with_weight_prefetch/gemm_with_weight_prefetch_commandline.hpp create mode 100644 examples/63_hopper_gemm_with_weight_prefetch/kernel/sm90_gemm_tma_warpspecialized_with_prefetch.hpp create mode 100644 examples/63_hopper_gemm_with_weight_prefetch/pipeline/prefetch_pipeline_sm90.hpp create mode 100644 include/cute/arch/config.hpp create mode 100644 include/cute/arch/mma_sm90_gmma_sparse.hpp create mode 100644 include/cute/atom/mma_traits_sm90_gmma_sparse.hpp create mode 100644 include/cute/pointer_sparse.hpp create mode 100644 include/cute/tensor_zip.hpp create mode 100644 include/cutlass/arch/config.h create mode 100644 include/cutlass/arch/grid_dependency_control.h create mode 100644 include/cutlass/arch/synclog.hpp create mode 100644 include/cutlass/conv/detail.hpp create mode 100644 include/cutlass/epilogue/collective/sm70_epilogue_vectorized_array.hpp create mode 100644 include/cutlass/epilogue/fusion/sm90_visitor_topk_softmax.hpp create mode 100644 include/cutlass/gemm/collective/builders/sm90_sparse_config.inl create mode 100644 include/cutlass/gemm/collective/builders/sm90_sparse_gmma_builder.inl create mode 100644 include/cutlass/gemm/collective/sm90_sparse_mma_tma_gmma_ss_warpspecialized.hpp create mode 100644 include/cutlass/transform/kernel/sm90_sparse_gemm_compressor.hpp create mode 100644 include/cutlass/transform/kernel/sparse_gemm_compressor.hpp create mode 100644 media/docs/dependent_kernel_launch.md create mode 100644 python/cutlass_library/sm90_shapes.py create mode 100644 python/cutlass_library/sm90_utils.py create mode 100644 test/unit/cute/core/swizzle_layout.cpp create mode 100644 test/unit/gemm/device/sm90_sparse_gemm_f16_f16_f32_tensor_op_f32.cu create mode 100644 test/unit/gemm/device/sm90_sparse_gemm_f8_f8_f32_tensor_op_f32.cu create mode 100644 test/unit/gemm/device/sm90_sparse_gemm_s8_s8_s32_tensor_op_s32.cu create mode 100644 test/unit/gemm/device/sm90_sparse_gemm_tf32_tf32_f32_tensor_op_f32.cu create mode 100644 test/unit/transform/device/CMakeLists.txt create mode 100644 test/unit/transform/device/sm90_sparse_gemm_compressor_f16.cu create mode 100644 test/unit/transform/device/sm90_sparse_gemm_compressor_f32.cu create mode 100644 test/unit/transform/device/sm90_sparse_gemm_compressor_f8.cu create mode 100644 test/unit/transform/device/sm90_sparse_gemm_compressor_legacy.hpp create mode 100644 test/unit/transform/device/testbed_sparse_gemm_compressor.hpp create mode 100644 tools/library/src/reference/gemm_s8_s8_s32.cu rename tools/library/src/reference/{gemm_int8_canonical.cu => gemm_u8_u8_s32.cu} (65%) create mode 100644 tools/library/src/sparse_gemm_operation_3x.hpp diff --git a/CHANGELOG.md b/CHANGELOG.md index c784107be9..c98cdb515f 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,25 @@ # NVIDIA CUTLASS Changelog +## [3.6.0](https://github.com/NVIDIA/cutlass/releases/tag/v3.6.0) (2024-10-03) + +- [Hopper structured sparse GEMM](./examples/62_hopper_sparse_gemm/62_hopper_sparse_gemm.cu). 
+ + [FP16](./test/unit/gemm/device/sm90_sparse_gemm_f16_f16_f32_tensor_op_f32.cu) + + [FP8](./test/unit/gemm/device/sm90_sparse_gemm_f8_f8_f32_tensor_op_f32.cu) + + [INT8](./test/unit/gemm/device/sm90_sparse_gemm_s8_s8_s32_tensor_op_s32.cu) + + [TF32](./test/unit/gemm/device/sm90_sparse_gemm_tf32_tf32_f32_tensor_op_f32.cu) +- A refactor of the CUTLASS 3.x convolution `kernel::ConvUniversal` [API](./include/cutlass/conv/kernel/sm90_implicit_gemm_tma_warpspecialized.hpp) to bring it in line with `gemm::GemmUniversal`. The 3.x convolution API is no longer considered a beta API. +- [An improved mixed input GEMM](./examples/55_hopper_mixed_dtype_gemm/README.md) and a [lookup table implementation](./examples/55_hopper_mixed_dtype_gemm/55_hopper_int4_fp8_gemm.cu) for `INT4`x`FP8` scale-only mode. +- [EVT nodes for Top-K selection and softmax](./include/cutlass/epilogue/fusion/sm90_visitor_topk_softmax.hpp) and a [GEMM example using those](./examples/61_hopper_gemm_with_topk_and_softmax/61_hopper_gemm_with_topk_and_softmax.cu). +- [Programmatic Dependent Launch](./include/cutlass/arch/grid_dependency_control.h) (PDL) that leverages a new Hopper feature to speed up two back-to-back kernels, and the corresponding [documentation](./media/docs/dependent_kernel_launch.md). +- [A new debugging tool, synclog](./include/cutlass/arch/synclog.hpp), for dumping out all synchronization events from within a kernel to a file. Please see [synclog documentation](./media/docs/utilities.md#debugging-asynchronous-kernels-with-cutlasss-built-in-synclog-tool) for details. +- A new TMA-enabled [epilogue](./include/cutlass/epilogue/collective/sm90_epilogue_array_tma_warpspecialized.hpp) for grouped GEMM that brings a significant performance improvement, as well as EVT support. +- A SIMT-enabled pointer-array [epilogue](./include/cutlass/epilogue/collective/sm70_epilogue_vectorized_array.hpp). +- A new [Ping-Pong kernel schedule for Grouped GEMM](./include/cutlass/gemm/kernel/sm90_gemm_array_tma_warpspecialized_pingpong.hpp) and some other optimizations. +- [A new instantiation strategy for CUTLASS profiler kernels](./python/cutlass_library/sm90_shapes.py) along with [improved documentation for instantiation level in CUTLASS profiler](./media/docs/profiler.md#instantiating-more-kernels-with-hopper). +- New hardware support for comparisons and computations of [`cutlass::bfloat16_t`](./include/cutlass/bfloat16.h). +- Fixed use of `isnan` on Windows for [`half_t`](./test/unit/core/functional.cu). + Various improvements and fixes from the community and CUTLASS team. Thanks to everyone who submitted PRs!
+ ## [3.5.1](https://github.com/NVIDIA/cutlass/releases/tag/v3.5.1) (2024-07-25) - [Minimal SM90 WGMMA + TMA GEMM example in 100 lines of code](./examples/cute/tutorial/wgmma_sm90.cu) diff --git a/CMakeLists.txt b/CMakeLists.txt index 7419bdf5e5..e61b66a877 100755 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -134,7 +134,6 @@ set(CUTLASS_ENABLE_PERFORMANCE ${CUTLASS_ENABLE_PROFILER} CACHE BOOL "Enable CUT set(CUTLASS_ENABLE_TESTS ${CUTLASS_ENABLE_TESTS_INIT} CACHE BOOL "Enable CUTLASS Tests") set(CUTLASS_ENABLE_GTEST_UNIT_TESTS ${CUTLASS_ENABLE_TESTS} CACHE BOOL "Enable CUTLASS GTest-based Unit Tests") set(CUTLASS_USE_SYSTEM_GOOGLETEST OFF CACHE BOOL "Use system/external installation of GTest") - set(CUTLASS_USE_PACKED_TUPLE ON CACHE BOOL "If ON, make cute::tuple be new standard-layout tuple type; if OFF, use the original cute::tuple implementation that is _not_ standard-layout.") if (CUTLASS_USE_PACKED_TUPLE) list(APPEND CUTLASS_CUDA_NVCC_FLAGS -DCUTE_USE_PACKED_TUPLE=1) @@ -234,7 +233,6 @@ set(CUTLASS_NVCC_EMBED_CUBIN ON CACHE BOOL "Embed compiled CUDA kernel binaries set(CUTLASS_NVCC_EMBED_PTX ON CACHE BOOL "Embed compiled PTX into executables.") set(CUTLASS_NVCC_KEEP OFF CACHE BOOL "Keep intermediate files generated by NVCC.") set(CUTLASS_ENABLE_F16C OFF CACHE BOOL "Enable F16C x86 extensions in host code.") -set(CUTLASS_ENABLE_DIRECT_CUDA_DRIVER_CALL OFF CACHE BOOL "Enable CUTLASS to directly call driver API.") ################################################################################ # @@ -271,6 +269,7 @@ set(CUTLASS_LIBRARY_OPERATIONS "all" CACHE STRING "Comma-delimited list of opera set(CUTLASS_LIBRARY_KERNELS ${CUTLASS_LIBRARY_KERNELS_INIT} CACHE STRING "Comma-delimited list of kernel name filters. If unspecified, only the largest tile size is enabled. If the string 'all' is specified, all kernels are enabled.") set(CUTLASS_LIBRARY_IGNORE_KERNELS "" CACHE STRING "Comma-delimited list of kernels to exclude from build. This option ONLY takes effect if CUTLASS_LIBRARY_KERNELS is set.") set(CUTLASS_LIBRARY_EXCLUDE_KERNELS "" CACHE STRING "Comma-delimited list of kernels to exclude from build. This option always takes effect, whether or not CUTLASS_LIBRARY_KERNELS is set. It also can exclude kernels from the filter file (see KERNEL_FILTER_FILE).") +set(CUTLASS_LIBRARY_INSTANTIATION_LEVEL "" CACHE STRING "Instantiation level for SM90 kernels. 
Set to `max` and make sure CUTLASS_LIBRARY_KERNELS is non-empty to stamp all possible kernel configurations.") ################################################################################ @@ -318,6 +317,8 @@ if(CUTLASS_ENABLE_SM90_EXTENDED_MMA_SHAPES) list(APPEND CUTLASS_CUDA_NVCC_FLAGS -DCUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) endif() +set(CUTLASS_SKIP_REDUCTION_INIT OFF CACHE BOOL "Disable init reduction workspace") + # # NOTE: running with asan and CUDA requires the following environment variable: # @@ -345,6 +346,10 @@ if(CUTLASS_NVCC_EMBED_PTX) list(APPEND CUTLASS_CUDA_CLANG_FLAGS --cuda-include-ptx=all) endif() +if (CUTLASS_SKIP_REDUCTION_INIT) + list(APPEND CUTLASS_CUDA_NVCC_FLAGS -DCUTLASS_SKIP_REDUCTION_INIT=1) +endif() + if (CUTLASS_ENABLE_TENSOR_CORE_MMA) list(APPEND CUTLASS_CUDA_FLAGS -DCUTLASS_ENABLE_TENSOR_CORE_MMA=1) endif() @@ -354,6 +359,18 @@ if (CUTLASS_PROFILER_DISABLE_REFERENCE) list(APPEND CUTLASS_CUDA_NVCC_FLAGS -DCUTLASS_PROFILER_DISABLE_REFERENCE=1) endif() +if (CUTLASS_ENABLE_GDC_FOR_SM90) + message(STATUS "Grid Dependency Control (GDC) is enabled for SM90 kernels (required for programmatic dependent launches).") + list(APPEND CUTLASS_CUDA_NVCC_FLAGS -DCUTLASS_ENABLE_GDC_FOR_SM90=1) +endif() + +set(CUTLASS_ENABLE_SYNCLOG OFF CACHE BOOL "Enable synchronization event logging for race condition debugging. WARNING: This redefines __syncthreads() and __syncwarp() in all downstream code!") + +if (CUTLASS_ENABLE_SYNCLOG) + set(CMAKE_CUDA_SEPARABLE_COMPILATION ON) + string(APPEND CMAKE_CXX_FLAGS " -DCUTLASS_ENABLE_SYNCLOG=1") + string(APPEND CMAKE_CUDA_FLAGS " -DCUTLASS_ENABLE_SYNCLOG=1") +endif() @@ -880,12 +897,27 @@ function(cutlass_add_executable_tests NAME TARGET) set(TEST_GROUP_NAME ${NAME}) + # To run the tests from an install package with tests enabled, we need to generate test files + # that don't rely on the current directory structure in build. + + set(TEST_NAME c${NAME}) + set(TEST_GEN_DIR ${CMAKE_CURRENT_BINARY_DIR}/ctest/${TEST_NAME}) + file(MAKE_DIRECTORY ${TEST_GEN_DIR}) + + set(TEST_EXE_PATH $) + set(TEST_USE_EXTENDED_FORMAT ON) + configure_file("${CUTLASS_CTEST_TEMPLATE_FILE}" "${TEST_GEN_DIR}/CTestTestfile.${TEST_NAME}.cmake" @ONLY) + + set(TEST_EXE_PATH $) + set(TEST_USE_EXTENDED_FORMAT OFF) # ctest does not support extended add_test format. 
+ configure_file("${CUTLASS_CTEST_TEMPLATE_FILE}" "${TEST_GEN_DIR}/CTestTestfile.${TEST_NAME}.install.cmake.in" @ONLY) + foreach(CMD_OPTIONS_VAR IN LISTS __TEST_COMMAND_OPTIONS) if (CMD_COUNT GREATER 1) - string(TOLOWER "${NAME}_${CMD_OPTIONS_VAR}" TEST_NAME) + string(TOLOWER "${NAME}_${CMD_OPTIONS_VAR}" TESTCASE_NAME) else() - string(TOLOWER "${NAME}" TEST_NAME) + string(TOLOWER "${NAME}" TESTCASE_NAME) endif() # The following rigmarole is needed to deal with spaces and possible quotes in @@ -899,7 +931,7 @@ function(cutlass_add_executable_tests NAME TARGET) separate_arguments(TEST_COMMAND_OPTIONS) add_custom_target( - ${TEST_NAME} + ${TESTCASE_NAME} COMMAND ${CUTLASS_TEST_EXECUTION_ENVIRONMENT} $ ${TEST_COMMAND_OPTIONS} DEPENDS @@ -907,34 +939,20 @@ function(cutlass_add_executable_tests NAME TARGET) ) if (CMD_COUNT GREATER 1) - add_dependencies(${NAME} ${TEST_NAME}) + add_dependencies(${NAME} ${TESTCASE_NAME}) endif() foreach(DEPENDEE ${__DEPENDEES}) - add_dependencies(${DEPENDEE} ${TEST_NAME}) + add_dependencies(${DEPENDEE} ${TESTCASE_NAME}) endforeach() - set(TEST_NAME c${TEST_NAME}) + set(TESTCASE_NAME c${TESTCASE_NAME}) string(CONFIGURE "${_INLINE_PER_TEST_CODE_TEMPLATE}" _TEST_CODE @ONLY) - string(APPEND _INLINE_PER_TEST_CODE "${_TEST_CODE}") + file(APPEND "${TEST_GEN_DIR}/CTestTestfile.${TEST_NAME}.cmake" "${_TEST_CODE}") + file(APPEND "${TEST_GEN_DIR}/CTestTestfile.${TEST_NAME}.install.cmake.in" "${_TEST_CODE}") endforeach() - # To run the tests from an install package with tests enabled, we need to generate test files - # that don't rely on the current directory structure in build. - - set(TEST_NAME c${NAME}) - set(TEST_GEN_DIR ${CMAKE_CURRENT_BINARY_DIR}/ctest/${TEST_NAME}) - file(MAKE_DIRECTORY ${TEST_GEN_DIR}) - - set(TEST_EXE_PATH $) - set(TEST_USE_EXTENDED_FORMAT ON) - configure_file("${CUTLASS_CTEST_TEMPLATE_FILE}" "${TEST_GEN_DIR}/CTestTestfile.${TEST_NAME}.cmake" @ONLY) - - set(TEST_EXE_PATH $) - set(TEST_USE_EXTENDED_FORMAT OFF) # ctest does not support extended add_test format. - configure_file("${CUTLASS_CTEST_TEMPLATE_FILE}" "${TEST_GEN_DIR}/CTestTestfile.${TEST_NAME}.install.cmake.in" @ONLY) - # The following line imports the tests for immediate run via `make test`. include(${TEST_GEN_DIR}/CTestTestfile.${TEST_NAME}.cmake) diff --git a/PUBLICATIONS.md b/PUBLICATIONS.md index 04d4cd0a14..b7425f251d 100644 --- a/PUBLICATIONS.md +++ b/PUBLICATIONS.md @@ -26,6 +26,8 @@ - ["Mixed Precision Post Training Quantization of Neural Networks with Sensitivity Guided Search"](https://arxiv.org/abs/2302.01382). Clemens JS Schaefer, Elfie Guo, Caitlin Stanton, Xiaofan Zhang, Tom Jablin, Navid Lambert-Shirzad, Jian Li, Chiachen Chou, Siddharth Joshi, Yu Emma Wang. _arXiv_, Feburary 2023. +- ["Dynamic N:M Fine-Grained Structured Sparse Attention Mechanism"](https://dl.acm.org/doi/abs/10.1145/3572848.3577500). Zhaodong Chen, Zheng Qu, Yuying Quan, Liu Liu, Yufei Ding, Yuan Xie. _Proceedings of the 28th ACM SIGPLAN Annual Symposium on Principles and Practice of Parallel Programming_, Feburary 2023. + - ["Stream-K: Work-centric Parallel Decomposition for Dense Matrix-Matrix Multiplication on the GPU"](https://arxiv.org/abs/2301.03598). Muhammad Osama, Duane Merrill, Cris Cecka, Michael Garland, John D. Owens. _arXiv_, January 2023. 
## 2022 diff --git a/README.md b/README.md index 1426e8a42e..e61335f240 100644 --- a/README.md +++ b/README.md @@ -1,8 +1,8 @@ ![ALT](./media/images/gemm-hierarchy-with-epilogue-no-labels.png "Complete CUDA GEMM decomposition") -# CUTLASS 3.5.1 +# CUTLASS 3.6.0 -_CUTLASS 3.5.1 - July 2024_ +_CUTLASS 3.6.0 - October 2024_ CUTLASS is a collection of CUDA C++ template abstractions for implementing high-performance matrix-matrix multiplication (GEMM) and related computations at all levels @@ -42,48 +42,26 @@ and improves code composability and readability. More documentation specific to In addition to GEMMs, CUTLASS implements high-performance convolution via the implicit GEMM algorithm. Implicit GEMM is the formulation of a convolution operation as a GEMM thereby taking advantage of CUTLASS's modular GEMM pipeline. This allows CUTLASS to build convolutions by reusing highly-optimized GEMM components. -# What's New in CUTLASS 3.5 - -CUTLASS 3.5.1 is an update to CUTLASS adding: - -- [Minimal SM90 WGMMA + TMA GEMM example in 100 lines of code](./examples/cute/tutorial/wgmma_sm90.cu). -- [Exposure of L2 `cache_hint`s in TMA copy atoms](./include/cute/arch/copy_sm90_tma.hpp#L48) -- Exposure of raster order and tile swizzle extent in [CUTLASS library profiler](./media/docs/profiler.md#GEMM), and -[example 48](./examples/48_hopper_warp_specialized_gemm/48_hopper_warp_specialized_gemm.cu). -- [TMA store based and EVT supported epilogues](./include/cutlass/epilogue/collective/sm90_epilogue_array_tma_warpspecialized.hpp) for [Hopper pointer array batched kernels](./test/unit/gemm/device/sm90_gemm_f16_f16_f16_tensor_op_f32_ptr_array.cu). -- A new [`GemmSparseUniversal` API for CUTLASS 2.x Ampere kernels](./include/cutlass/gemm/device/gemm_sparse_universal.h) to enable serial and parallel split-k for sparse tensor cores and new tiny tile sizes to better support LLM inference. -- [CUDA host adapter](./include/cutlass/cuda_host_adapter.hpp) extensions to support TMA descriptor construction driver APIs. -- Inclusion of more [Hopper fprop, dgrad, and wgrad convolution kernels in CUTLASS library and profiler](./python/cutlass_library/generator.py). -- Support for residual add (beta != 0) in convolution kernels. -- A new convolution [epilogue](./examples/16_ampere_tensorop_conv2dfprop/ampere_tensorop_conv2dfprop.cu#L269) for CUTLASS 2.x to support non-packed NHWC output. -- A refactor of [include files throughout CUTLASS core directories](./include/cutlass/gemm/collective/collective_mma_decl.hpp) to reduce circular dependencies and [tests to guard against them](./test/self_contained_includes/CMakeLists.txt). -- [A guide for setting up VSCode to work well with CUTLASS](./media/docs/ide_setup.md) and [expanded code style guide](./media/docs/programming_guidelines.md). -- Better support for MSVC as a host compiler. -- Many performance optimizations, improvements, and bug fixes including fixes for FlashAttention-2. -- Optimal code generation with CUDA toolkit versions 12.4 and 12.5u1. -- NOTICE: - + Upcoming CUTLASS 3.6 release will include a breaking refactor to the CUTLASS 3.x convolution `kernel::ConvUniversal` API to bring it in line with `gemm::GemmUniversal`. After this, the 3.x convolution API will no longer be considered as a beta API. - + Upcoming CUTLASS 3.6 release will include a breaking refactor to the Hopper TMA pointer array batched epilogue in order to support grouped GEMMs. 
- -CUTLASS 3.5.0 is an update to CUTLASS adding: - -- Implicit GEMM Convolutions targeting Hopper SM90A via WGMMA + [TMA im2col](./include/cute/atom/copy_traits_sm90_im2col.hpp). - + Native implementation in CUTLASS 3.x using CuTe, mirroring the [same design hierarchy as that of GEMMs](./media/docs/gemm_api_3x.md). - + Support for 1D, 2D, and 3D convolutions in a [rank-agnostic fashion](./include/cutlass/conv/convnd_problem_shape.hpp). - + Support for [Fprop](./test/unit/conv/device_3x/fprop/sm90_conv3d_fprop_implicit_gemm_s8_s8_s32_tensorop_s32.cu), [Dgrad](./test/unit/conv/device_3x/dgrad/sm90_conv2d_dgrad_implicit_gemm_f16_f16_f32_tensorop_f16.cu), and [Wgrad](./test/unit/conv/device_3x/wgrad/sm90_conv1d_wgrad_implicit_gemm_f16_f16_f32_tensorop_f16.cu) algorithms. - + [CUTLASS profiler support](./python/cutlass_library/conv3x_emitter.py) for 2D and 3D convolutions implemented via the 3.x API. - + NOTE: this is a beta release. Further updates to CUTLASS will include major performance improvements, feature enablement, and possible breaking changes to the API until 3.7 release. Your feedback is welcome on the design! -- Support for [Ada (SM89) FP8 tensor cores via the 2.x API](./examples/58_ada_fp8_gemm/ada_fp8_gemm.cu). Requires CUDA 12.4 or newer. -- [Ampere gather/scatter convolution example](./examples/59_ampere_gather_scatter_gemm/README.md) in CuTe and CUTLASS 3.x. - + Showcasing how custom kernels can be written and optimized using CUTLASS 3.x and CuTe and the general strategy for implementing convolutions as specializations of GETTs. - + Implementation of a coarse grained sparse gather/scatter kernel achieving peak performance on Ampere class tensor cores. -- 32x and 16x tile sizes are added to CUTLASS 2.x to improve the performance of narrow-tall and wide-short matrices. -- Updates to CuTe documentation for [`cute::Tensor<>`](./media/docs/cute/03_tensor.md), [MMA atoms](./media/docs/cute/0t_mma_atom.md), and an overhauled [CuTe GEMM tutorial series](./examples/cute/tutorial). -- Extensions to CuTe to support [L2 prefetching](./include/cute/algorithm/prefetch.hpp) and [TMA store+reductions](./include/cute/arch/copy_sm90_tma.hpp#L1337). -- Remove C++11 requirement on a few CUTLASS 2.x API header files. All CUTLASS files now require C++17. -- Fixes to greatly reduce build warnings. -- Updates and bugfixes from the community (thanks!) -- CUTLASS 3.5.1 is a minor update to CUTLASS containing small bug fixes and improvements, including fixes for FlashAttention-2 builds. +# What's New in CUTLASS 3.6 + +CUTLASS 3.6.0 is an update to CUTLASS adding: + +- [Hopper structured sparse GEMM](./examples/62_hopper_sparse_gemm/62_hopper_sparse_gemm.cu). + + [FP16](./test/unit/gemm/device/sm90_sparse_gemm_f16_f16_f32_tensor_op_f32.cu) + + [FP8](./test/unit/gemm/device/sm90_sparse_gemm_f8_f8_f32_tensor_op_f32.cu) + + [INT8](./test/unit/gemm/device/sm90_sparse_gemm_s8_s8_s32_tensor_op_s32.cu) + + [TF32](./test/unit/gemm/device/sm90_sparse_gemm_tf32_tf32_f32_tensor_op_f32.cu) +- A refactor to the CUTLASS 3.x convolution `kernel::ConvUniversal` [API](./include/cutlass/conv/kernel/sm90_implicit_gemm_tma_warpspecialized.hpp) to bring it in line with `gemm::GemmUniversal`. Now the 3.x convolution API is no longer considered as a beta API. +- [An improved mixed input GEMM](./examples/55_hopper_mixed_dtype_gemm/README.md) and a [lookup table implementation](./examples/55_hopper_mixed_dtype_gemm/55_hopper_int4_fp8_gemm.cu) for `INT4`x`FP8` scale-only mode. 
+- [EVT nodes for Top-K selection and softmax](./include/cutlass/epilogue/fusion/sm90_visitor_topk_softmax.hpp) and a [GEMM example using those](./examples/61_hopper_gemm_with_topk_and_softmax/61_hopper_gemm_with_topk_and_softmax.cu). +- [Programmatic Dependent Launch](./include/cutlass/arch/grid_dependency_control.h) (PDL) that leverages a new Hopper feature to speed up two back-to-back kernels, and the corresponding [documentation](./media/docs/dependent_kernel_launch.md). +- [A new debugging tool, synclog](./include/cutlass/arch/synclog.hpp), for dumping out all synchronization events from within a kernel to a file. Please see [synclog documentation](./media/docs/utilities.md#debugging-asynchronous-kernels-with-cutlasss-built-in-synclog-tool) for details. +- A new TMA-enabled [epilogue](./include/cutlass/epilogue/collective/sm90_epilogue_array_tma_warpspecialized.hpp) for grouped GEMM that brings a significant performance improvement, as well as EVT support. +- A SIMT-enabled pointer-array [epilogue](./include/cutlass/epilogue/collective/sm70_epilogue_vectorized_array.hpp). +- A new [Ping-Pong kernel schedule for Grouped GEMM](./include/cutlass/gemm/kernel/sm90_gemm_array_tma_warpspecialized_pingpong.hpp) and some other optimizations. +- [A new instantiation strategy for CUTLASS profiler kernels](./python/cutlass_library/sm90_shapes.py) along with [improved documentation for instantiation level in CUTLASS profiler](./media/docs/profiler.md#instantiating-more-kernels-with-hopper). +- New hardware support for comparisons and computations of [`cutlass::bfloat16_t`](./include/cutlass/bfloat16.h). +- Fixed use of `isnan` on Windows for [`half_t`](./test/unit/core/functional.cu). Minimum requirements: @@ -163,7 +141,7 @@ CUTLASS runs successfully on the following NVIDIA GPUs, and it is expected to be In general, PTX code generated for one target architecture can be run on future architectures (i.e., it is forward compatible). However, CUDA 12.0 introduced the concept of "architecture-accelerated features" whose PTX does not have forward compatibility guarantees. Several Hopper PTX instructions fall under this category of architecture-accelerated features, and thus require a `sm_90a` target architecture (note the "a" appended). For more details on this and other architecture-accelerated instructions, please refer to the [CUDA Documentation](https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#feature-availability). -The target architecture information is passed on to CUTLASS via the cmake flag `CUTLASS_NVCC_ARCHS`. In order to maximize performance on Hopper GH100, users are required to build CUTLASS with `90a` as the target architecture. If a user accidentally builds a kernel which uses SM90a features (e.g. Hopper Tensor Core Instructions), using the SM90 target (note the lack of "a"), with either CTK 12 or 11.8, the kernel is expected to fail with a runtime error. +The target architecture information is passed on to CUTLASS via the cmake flag `CUTLASS_NVCC_ARCHS`. In order to maximize performance on Hopper GH100, users are required to build CUTLASS with `90a` as the target architecture. If a user accidentally builds a kernel which uses SM90a features (e.g. Hopper Tensor Core Instructions), using the SM90 target (note the lack of "a"), with either CUDA Toolkit 12 or 11.8, the kernel is expected to fail with a runtime error. ``` cmake ..
-DCUTLASS_NVCC_ARCHS="90a" @@ -191,6 +169,8 @@ CUTLASS is described in the following documents and the accompanying - [Tile Iterators](./media/docs/tile_iterator_concept.md) - describes C++ concepts for iterating over tiles of matrices in memory - [CUTLASS Profiler](./media/docs/profiler.md) - command-line driven profiling application - [CUTLASS Utilities](./media/docs/utilities.md) - additional templates used to facilate rapid development +- [Dependent kernel launch](./media/docs/dependent_kernel_launch.md) - describes a new feature in Hopper which allows overlapping dependent +kernels in the same stream, and how it is used in CUTLASS. # Resources We have also described the structure of an efficient GEMM in our talk at the diff --git a/cmake/CTestTestfile.configure.cmake b/cmake/CTestTestfile.configure.cmake index 94394a5000..611b3d181f 100644 --- a/cmake/CTestTestfile.configure.cmake +++ b/cmake/CTestTestfile.configure.cmake @@ -50,5 +50,3 @@ if (DEFINED ENV{CUTLASS_TEST_EXECUTION_ENVIRONMENT}) else() set(_CUTLASS_TEST_EXECUTION_ENVIRONMENT @CUTLASS_TEST_EXECUTION_ENVIRONMENT@) endif() - -@_INLINE_PER_TEST_CODE@ diff --git a/cmake/CTestTestfile.test.configure.cmake b/cmake/CTestTestfile.test.configure.cmake index fa2ceeb9bd..31dba54498 100644 --- a/cmake/CTestTestfile.test.configure.cmake +++ b/cmake/CTestTestfile.test.configure.cmake @@ -30,14 +30,14 @@ if (CUTLASS_USE_EXTENDED_ADD_TEST_FORMAT) # The longform/extended format allows generator expressions to be # expanded property and is useful in contexts where the files need # to be immediately included into being-processed cmake code. - add_test(NAME @TEST_NAME@ COMMAND ${_CUTLASS_TEST_EXECUTION_ENVIRONMENT} "${TEST_EXE_PATH}" @TEST_COMMAND_OPTIONS@) + add_test(NAME @TESTCASE_NAME@ COMMAND ${_CUTLASS_TEST_EXECUTION_ENVIRONMENT} "${TEST_EXE_PATH}" @TEST_COMMAND_OPTIONS@) else() - add_test(@TEST_NAME@ ${_CUTLASS_TEST_EXECUTION_ENVIRONMENT} "${TEST_EXE_PATH}" @TEST_COMMAND_OPTIONS@) + add_test(@TESTCASE_NAME@ ${_CUTLASS_TEST_EXECUTION_ENVIRONMENT} "${TEST_EXE_PATH}" @TEST_COMMAND_OPTIONS@) endif() if (TEST_EXE_WORKING_DIRECTORY) - set_tests_properties(@TEST_NAME@ PROPERTIES WORKING_DIRECTORY "${TEST_EXE_WORKING_DIRECTORY}") + set_tests_properties(@TESTCASE_NAME@ PROPERTIES WORKING_DIRECTORY "${TEST_EXE_WORKING_DIRECTORY}") endif() -set_tests_properties(@TEST_NAME@ PROPERTIES DISABLED @__DISABLE_TESTS@) +set_tests_properties(@TESTCASE_NAME@ PROPERTIES DISABLED @__DISABLE_TESTS@) diff --git a/cmake/googletest.cmake b/cmake/googletest.cmake index 0350fb2dd1..d220cfadc2 100644 --- a/cmake/googletest.cmake +++ b/cmake/googletest.cmake @@ -34,9 +34,10 @@ if(GOOGLETEST_DIR) set(FETCHCONTENT_SOURCE_DIR_GOOGLETEST ${GOOGLETEST_DIR} CACHE STRING "GoogleTest source directory override") endif() +set(GTEST_REPOSITORY "https://github.com/google/googletest.git" CACHE STRING "GoogleTest repo to fetch") FetchContent_Declare( googletest - GIT_REPOSITORY https://github.com/google/googletest.git + GIT_REPOSITORY ${GTEST_REPOSITORY} GIT_TAG v1.14.0 ) diff --git a/examples/35_gemm_softmax/gemm_softmax.cu b/examples/35_gemm_softmax/gemm_softmax.cu index 27156ea02d..731e37b4d9 100644 --- a/examples/35_gemm_softmax/gemm_softmax.cu +++ b/examples/35_gemm_softmax/gemm_softmax.cu @@ -42,7 +42,8 @@ #include "cutlass/arch/memory.h" #include "cutlass/arch/memory_sm75.h" #include "cutlass/gemm/device/gemm_complex.h" - +#include "cutlass/numeric_types.h" +#include "cutlass/numeric_size.h" #include "cutlass/util/command_line.h" #include "cutlass/util/host_tensor.h" @@ 
-56,6 +57,7 @@ #include "cutlass/util/reference/host/tensor_fill.h" #include "cutlass/util/reference/host/error_metrics.h" #include "cutlass/util/tensor_view_io.h" +#include "cutlass/numeric_size.h" // cutlass::bits_to_bytes #include "cutlass/layout/matrix.h" #include "cutlass/epilogue/thread/linear_combination.h" @@ -657,7 +659,9 @@ struct Testbed { } int64_t flops = int64_t(options.problem_size.m()) * options.problem_size.n() * options.problem_size.k() * 2; - int64_t bytes = (sizeof(ElementD) * 2 + sizeof(ElementSoftmax)) * options.problem_size.m() * options.problem_size.n(); + int64_t bytes = cutlass::bits_to_bytes( + (cutlass::sizeof_bits::value * 2 + cutlass::sizeof_bits::value) * + options.problem_size.m() * options.problem_size.n()); double gflops_per_second = double(flops) * kIterations * options.batch_count / double(elapsed_ms / 1000.0f) / double(1.0e9); double gbytes_per_second = double(bytes) * kIterations * options.batch_count / double(elapsed_ms / 1000.0f) / double(1 << 30); diff --git a/examples/48_hopper_warp_specialized_gemm/48_hopper_warp_specialized_gemm.cu b/examples/48_hopper_warp_specialized_gemm/48_hopper_warp_specialized_gemm.cu index f26f4da37d..164c785e01 100644 --- a/examples/48_hopper_warp_specialized_gemm/48_hopper_warp_specialized_gemm.cu +++ b/examples/48_hopper_warp_specialized_gemm/48_hopper_warp_specialized_gemm.cu @@ -303,14 +303,14 @@ bool initialize_block( int bits_input = cutlass::sizeof_bits::value; if (bits_input == 1) { - scope_max = 2; - scope_min = 0; + scope_max = Element(2); + scope_min = Element(0); } else if (bits_input <= 8) { - scope_max = 2; - scope_min = -2; + scope_max = Element(2); + scope_min = Element(-2); } else { - scope_max = 8; - scope_min = -8; + scope_max = Element(8); + scope_min = Element(-8); } cutlass::reference::device::BlockFillRandomUniform( diff --git a/examples/52_hopper_gather_scatter_fusion/gather_gemm.hpp b/examples/52_hopper_gather_scatter_fusion/gather_gemm.hpp index 57053b0f9a..c71109aa79 100644 --- a/examples/52_hopper_gather_scatter_fusion/gather_gemm.hpp +++ b/examples/52_hopper_gather_scatter_fusion/gather_gemm.hpp @@ -111,7 +111,7 @@ class GemmGather EpilogueTensorStorage epilogue; } tensors; - struct PipelineStorage : cute::aligned_struct<16> { + struct PipelineStorage : cute::aligned_struct<16, _2> { using MainloopPipelineStorage = typename CollectiveMainloop::PipelineStorage; using EpiLoadPipelineStorage = typename CollectiveEpilogue::PipelineStorage; diff --git a/examples/53_hopper_gemm_permute/permute_traits.hpp b/examples/53_hopper_gemm_permute/permute_traits.hpp index 96fcc64cf9..4c5baccac5 100644 --- a/examples/53_hopper_gemm_permute/permute_traits.hpp +++ b/examples/53_hopper_gemm_permute/permute_traits.hpp @@ -50,7 +50,7 @@ struct PermuteTraits {}; using X = Underscore; // Reshape a rank-2 shape into a multidimensional shape. -// Input: +// Input: // shape = (A, B, ...) // target_shape = ((A1, ..., X, ..., Am), (B1, ..., X, ..., Bn), ...) // Output: @@ -76,12 +76,12 @@ reshape(Shape const& shape, TargetShape const& target_shape) // - sub-modes corresponding to the implied multidimensional shape of the source tensor // - strides accounting for the permutation operation being performed template -constexpr auto +constexpr auto make_permute_layout(Layout const& layout) { static_assert(cute::rank(Shape{}) == 3, "Only rank-3 layouts are supported"); if constexpr (Transpose) { // Deal with tensor B by transposing appropriately before and after computing the permute layout. 
- // Its CuTe-canonical mode order is [N,K,L], while permute operations expect [row,col,batch]. + // Its CuTe-canonical mode order is [N,K,L], while permute operations expect [row,col,batch]. return select<1,0,2>(make_permute_layout(select<1,0,2>(layout))); } else { @@ -129,23 +129,24 @@ inverse(Permutation const & perm) { template using inverse_t = decltype(inverse(T{})); -// Given a rank-2 layout of tensor that is assumed to have been permuted, +// Given a rank-2 layout of tensor that is assumed to have been permuted, // compute the original rank-2 layout of the tensor prior to the permutation. -// This is needed to form the correct input to the standalone permutation kernel. +// This is needed to form the correct input to the standalone permutation kernel. template -constexpr auto +constexpr auto make_original_layout(Layout const& layout) { static_assert(cute::rank(Shape{}) == 3, "Only rank-3 layouts are supported"); if constexpr (Transpose) { // Deal with tensor B by transposing appropriately before and after computing the permute layout. - // Its CuTe-canonical mode order is [N,K,L], while permute operations expect [row,col,batch]. + // Its CuTe-canonical mode order is [N,K,L], while permute operations expect [row,col,batch]. return select<1,0,2>(make_original_layout(select<1,0,2>(layout))); } else { using ShapeProfile = typename PermuteTraits::ShapeProfile; + auto re_shape = flatten(reshape(layout.shape(), ShapeProfile{})); using IndexOrder = typename PermuteTraits::IndexOrder; + auto orig_shape = transform_leaf(IndexOrder{}, [&](auto i){ return get(re_shape); }); using OrigOrder = conditional_t(), seq<0,1,2>, seq<1,0,2>>; - auto orig_shape = select(flatten(reshape(layout.shape(), ShapeProfile{})), IndexOrder{}); // print("Permuted shape: "); print(reshape(layout.shape(), ShapeProfile{})); print("\n"); // print("Original shape: "); print(orig_shape); print("\n"); return make_ordered_layout(product_each(orig_shape), OrigOrder{}); @@ -202,7 +203,7 @@ struct PermuteTraits> }; template -struct PermuteTraits> +struct PermuteTraits> { static constexpr bool kBatched = true; using ShapeProfile = Shape>, Shape, Shape>; @@ -222,7 +223,7 @@ struct PermuteTraits> }; template -struct PermuteTraits> +struct PermuteTraits> { static constexpr bool kBatched = true; using ShapeProfile = Shape, Shape>, Shape>; diff --git a/examples/55_hopper_mixed_dtype_gemm/55_hopper_int4_fp8_gemm.cu b/examples/55_hopper_mixed_dtype_gemm/55_hopper_int4_fp8_gemm.cu new file mode 100644 index 0000000000..138f7a0402 --- /dev/null +++ b/examples/55_hopper_mixed_dtype_gemm/55_hopper_int4_fp8_gemm.cu @@ -0,0 +1,701 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. 
Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! \file + \brief Hopper GEMM example with different data types using CUTLASS 3.0 APIs for NVIDIA Hopper architecture + + This example shows how to perform INT4 x FP8 GEMM and scale up the INT4 weight during dequantization. It uses a look-up table to avoid the multiplications + between INT4 and FP8. To trigger this method, use cutlass::Array as the scale type in the collective's arguments. + + However, this algorithm requires changes to the encoding of INT4 weights and scale factors. These changes must happen before launching the GEMM. See the helper functions + `unify_quant_encoding`, `initialize_packed_scale`, and header `fp8_packed_scale.hpp` for details. + + In a nutshell, the positive values of INT4 weights need to be encoded in the same way as negative values except for the sign bit. For each scale factor, + 8 negative results (-8 x scale, -7 x scale, ... -1 x scale) are packed together, forming a cutlass::Array value. + + The narrower type always passes through the register file. Therefore, in cases where the narrower type is operand B, the collective will implicitly swap + A and B in the main loop. However, as a result of this collective performing implicit swaps, it does not support TMA epilogues. Consequently, it is essential to consider this when constructing the epilogue, + as illustrated in this example. + + Note that in this example, we explicitly swap A and B in order to use TMA epilogues. We do this since TMA epilogues are more performant on problem sizes of interest. + + It is expected that the scale's K dimension be scale_k = ceil_div(problem_k, group_size). + + Scales are always expected to be MN major. This means the fastest changing dimension must be M if A is scaled or N if B is scaled. + + If A is being scaled, the scales must have shape [M, scale_k], while if B is scaled, it must have shape [N, scale_k]. + + The implementation only supports "group-wise" scales. However, we can make it work for per-column scales by setting the group's size + equal to the gemm problem K. + + Limitations: + 1) Only supports INT4 x { FP8, INT8, UINT8 }. The scales must be the same as mma Type. Scale with zero-point mode is not supported. + 2) The INT4 weights and scale factors have additional encoding requirements. + 3) The scales must be MN major. That means if A is scaled, it must be column major, but if B is scaled it must be row major. + 4) The scales must have the same layout and groupsize. 
+ 5) The groupsize must be greater or equal to the tile shape k. + 6) Currently, TMA epilogues cannot be used when the narrow type is the B operand. This limitation arises because the implementation always swaps the + operands to ensure that the narrow type passes through the register file, and TMA epilogues do not currently support implicit swap + transpose operations. + We plan to address this limitation in the future. However, we address this in the example by explicitly swapping and transposing the operands. + + Optimizing suggestions: + 1) Use a small tile size, since the register pressure for this GEMM (and RS GEMM in general) is high (it uses a lot of register space). + + Examples: + + Runs the mixed input batched gemm (with batch size 2), converting B to the type of A (mode 0) + $ ./examples/55_hopper_mixed_dtype_gemm/55_hopper_int4_fp8_gemm --m=2048 --n=2048 --k=2048 --l=2 --mode=0 + + Runs the mixed input gemm, and applies a scaling factor to B before mma (mode 1). Applies a vector of scales to the entire + matrix (group size is the same as the gemm k dimension). + $ ./examples/55_hopper_mixed_dtype_gemm/55_hopper_int4_fp8_gemm --m=4096 --n=5120 --k=8192 --g=8192 --mode=1 +*/ + +#include + +#include "cutlass/cutlass.h" + +#include "cute/tensor.hpp" +#include "cutlass/tensor_ref.h" +#include "cutlass/epilogue/collective/default_epilogue.hpp" +#include "cutlass/epilogue/thread/linear_combination.h" +#include "cutlass/gemm/dispatch_policy.hpp" +#include "cutlass/gemm/collective/collective_builder.hpp" +#include "cutlass/epilogue/collective/collective_builder.hpp" +#include "cutlass/gemm/device/gemm_universal_adapter.h" +#include "cutlass/gemm/kernel/gemm_universal.hpp" + +#include "cutlass/util/command_line.h" +#include "cutlass/util/distribution.h" +#include "cutlass/util/host_tensor.h" +#include "cutlass/util/packed_stride.hpp" +#include "cutlass/util/tensor_view_io.h" +#include "cutlass/util/reference/device/tensor_fill.h" +#include "cutlass/util/reference/device/tensor_compare.h" + +#include "helper.h" +#include "unfused_weight_dequantize.hpp" +#include "packed_scale.hpp" + +using namespace cute; + +#if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) + +///////////////////////////////////////////////////////////////////////////////////////////////// +/// GEMM kernel configurations +///////////////////////////////////////////////////////////////////////////////////////////////// +using MmaType = cutlass::float_e4m3_t; +using QuantType = cutlass::int4b_t; +constexpr int TileShapeK = 128 * 8 / sizeof_bits::value; + +// A matrix configuration +using ElementA = MmaType; // Element type for A matrix operand +using LayoutA = cutlass::layout::RowMajor; // Layout type for A matrix operand +constexpr int AlignmentA = 128 / cutlass::sizeof_bits::value; // Memory access granularity/alignment of A matrix in units of elements (up to 16 bytes) + +// B matrix configuration +using ElementB = QuantType; // Element type for B matrix operand +using LayoutB = cutlass::layout::ColumnMajor; // Layout type for B matrix operand +constexpr int AlignmentB = 128 / cutlass::sizeof_bits::value; // Memory access granularity/alignment of B matrix in units of elements (up to 16 bytes) + +// This example manually swaps and transposes, so keep transpose of input layouts +using LayoutA_Transpose = typename cutlass::layout::LayoutTranspose::type; +using LayoutB_Transpose = typename cutlass::layout::LayoutTranspose::type; + +using ElementScale = MmaType; +using ElementZero = ElementScale; // only for verify +using 
LayoutScale = cutlass::layout::RowMajor; + +// C/D matrix configuration +using ElementC = cutlass::half_t; // Element type for C and D matrix operands +using LayoutC = cutlass::layout::RowMajor; // Layout type for C and D matrix operands +constexpr int AlignmentC = 128 / cutlass::sizeof_bits::value; // Memory access granularity/alignment of C matrix in units of elements (up to 16 bytes) + +// D matrix configuration +using ElementD = ElementC; +using LayoutD = LayoutC; +constexpr int AlignmentD = 128 / cutlass::sizeof_bits::value; + +// Core kernel configurations +using ElementAccumulator = float; // Element type for internal accumulation +using ElementCompute = float; // Element type for epilogue computation +using ArchTag = cutlass::arch::Sm90; // Tag indicating the minimum SM that supports the intended feature +using OperatorClass = cutlass::arch::OpClassTensorOp; // Operator class tag +using TileShape = Shape<_128,_128,cute::Int>; // Threadblock-level tile size +using ClusterShape = Shape<_1,_1,_1>; // Shape of the threadblocks in a cluster +using KernelSchedule = cutlass::gemm::KernelTmaWarpSpecializedCooperativeMixedInput; // Kernel to launch based on the default setting in the Collective Builder +using EpilogueSchedule = cutlass::epilogue::TmaWarpSpecializedCooperative; +using EpilogueTileType = cutlass::epilogue::collective::EpilogueTileAuto; + +using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + TileShape, ClusterShape, + EpilogueTileType, + ElementAccumulator, ElementAccumulator, + // Transpose layout of D here since we use explicit swap + transpose + // the void type for C tells the builder to allocate 0 smem for the C matrix. + // We can enable this if beta == 0 by changing ElementC to void below. + ElementC, typename cutlass::layout::LayoutTranspose::type, AlignmentC, + ElementD, typename cutlass::layout::LayoutTranspose::type, AlignmentD, + EpilogueSchedule // This is the only epi supporting the required swap + transpose. + >::CollectiveOp; + +// =========================================================== MIXED INPUT WITH SCALES =========================================================================== +// The Scale information must get paired with the operand that will be scaled. In this example, B is scaled so we make a tuple of B's information and the scale information. 
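+// As noted in the file comment above, passing the packed-scale array type (a cutlass::Array of the 8 precomputed
+// negative products per scale factor, see packed_scale.hpp and initialize_packed_scale below) as the scale element
+// in the collective's arguments is what selects the look-up-table dequantization path instead of plain multiplication.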
+using CollectiveMainloopScaleOnly = typename cutlass::gemm::collective::CollectiveBuilder< + ArchTag, OperatorClass, + cute::tuple >, LayoutB_Transpose, AlignmentB, + ElementA, LayoutA_Transpose, AlignmentA, + ElementAccumulator, + TileShape, ClusterShape, + cutlass::gemm::collective::StageCountAutoCarveout< + static_cast(sizeof(typename CollectiveEpilogue::SharedStorage)) + >, + KernelSchedule + >::CollectiveOp; + +using GemmKernelScaleOnly = cutlass::gemm::kernel::GemmUniversal< + Shape, // Indicates ProblemShape + CollectiveMainloopScaleOnly, + CollectiveEpilogue +>; + +using GemmScaleOnly = cutlass::gemm::device::GemmUniversalAdapter; + +using StrideA = cutlass::detail::TagToStrideA_t; +using StrideB = cutlass::detail::TagToStrideB_t; +using StrideC = typename GemmKernelScaleOnly::StrideC; +using StrideD = typename GemmKernelScaleOnly::StrideD; + +using StrideC_ref = cutlass::detail::TagToStrideC_t; +using StrideD_ref = cutlass::detail::TagToStrideC_t; + +// +// Data members +// + +/// Initialization +StrideA stride_A; +StrideB stride_B; +StrideC stride_C; +StrideC_ref stride_C_ref; +StrideD stride_D; +StrideD_ref stride_D_ref; +uint64_t seed; + +using StrideS = typename CollectiveMainloopScaleOnly::StrideScale; +using StrideS_ref = cutlass::detail::TagToStrideB_t; +StrideS stride_S; +StrideS_ref stride_S_ref; + +cutlass::DeviceAllocation block_A; +cutlass::DeviceAllocation block_B; +cutlass::DeviceAllocation block_B_modified; +cutlass::DeviceAllocation block_B_dq; +cutlass::DeviceAllocation block_scale; +cutlass::DeviceAllocation> block_scale_packed; +cutlass::DeviceAllocation block_zero; +cutlass::DeviceAllocation block_C; +cutlass::DeviceAllocation block_D; +cutlass::DeviceAllocation block_ref_D; + +#endif // defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) + +///////////////////////////////////////////////////////////////////////////////////////////////// +/// Testbed utility types +///////////////////////////////////////////////////////////////////////////////////////////////// + +// Command line options parsing +struct Options { + + bool help = false; + + float alpha = 1.0f; + float beta = 0.0f; + int iterations = 10; + int m = 5120, n = 4096, k = 4096; + int g = 128; + int l = 1; + + // Parses the command line + void parse(int argc, char const **args) { + cutlass::CommandLine cmd(argc, args); + + if (cmd.check_cmd_line_flag("help")) { + help = true; + return; + } + + cmd.get_cmd_line_argument("m", m); + cmd.get_cmd_line_argument("n", n); + cmd.get_cmd_line_argument("k", k); + cmd.get_cmd_line_argument("l", l); + cmd.get_cmd_line_argument("g", g); + cmd.get_cmd_line_argument("alpha", alpha, 1.f); + cmd.get_cmd_line_argument("beta", beta, 0.f); + cmd.get_cmd_line_argument("iterations", iterations); + } + + /// Prints the usage statement. + std::ostream & print_usage(std::ostream &out) const { + + out << "55_hopper_warp_specialized_gemm\n\n" + << " Hopper FP32 GEMM using a Warp Specialized kernel.\n\n" + << "Options:\n\n" + << " --help If specified, displays this usage statement\n\n" + << " --m= Sets the M extent of the GEMM\n" + << " --n= Sets the N extent of the GEMM\n" + << " --k= Sets the K extent of the GEMM\n" + << " --l= The number of independent gemm problems with mnk shape\n" + << " --g= The size of each group for the scales. 
To broadcast a vector of scales or zeros, set the group size to K.\n" + << " --alpha= Epilogue scalar alpha\n" + << " --beta= Epilogue scalar beta\n\n" + << " --iterations= Number of profiling iterations to perform.\n\n"; + + out + << "\n\nExamples:\n\n" + << "$ " << "55_hopper_warp_specialized_gemm" << " --m=1024 --n=512 --k=1024 -g 0 --l=10 --alpha=2 --mode=2 --beta=0.707 \n\n"; + + return out; + } + + /// Compute performance in GFLOP/s + double gflops(double runtime_s) const + { + // Two flops per multiply-add + uint64_t flop = uint64_t(2) * m * n * k * l; + double gflop = double(flop) / double(1.0e9); + return gflop / runtime_s; + } +}; + +/// Result structure +struct Result +{ + double avg_runtime_ms = 0.0; + double gflops = 0.0; + cutlass::Status status = cutlass::Status::kSuccess; + cudaError_t error = cudaSuccess; + bool passed = false; + +}; + +#if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) + +///////////////////////////////////////////////////////////////////////////////////////////////// +/// GEMM setup and evaluation +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Helper to initialize a block of device data +template +bool initialize_tensor( + cutlass::DeviceAllocation& block, + uint64_t seed=2023) { + + double scope_max, scope_min; + int bits_input = cutlass::sizeof_bits::value; + int bits_output = cutlass::sizeof_bits::value; + + if (bits_input == 1) { + scope_max = 2; + scope_min = 0; + } + else if (bits_input <= 8) { + scope_max = 2; + scope_min = -2; + } + else if (bits_output == 16) { + scope_max = 5; + scope_min = -5; + } + else { + scope_max = 8; + scope_min = -8; + } + cutlass::reference::device::BlockFillRandomUniform( + block.get(), block.size(), seed, Element(scope_max), Element(scope_min)); + + return true; +} + +template +bool initialize_quant_tensor( + cutlass::DeviceAllocation& block, + uint64_t seed=2023) { + + float scope_min = float(cutlass::platform::numeric_limits::lowest()); + float scope_max = float(cutlass::platform::numeric_limits::max()); + + cutlass::reference::device::BlockFillRandomUniform( + block.get(), block.size(), seed, Element(scope_max), Element(scope_min)); + + return true; +} + +// In the mainloop, PRMT selects 1 byte from only 8 bytes so the sign bit is handled in an extra PRMT. +// Here the encodings of positive values and negative values are unified (except for the sign bit). +// For instance, 1 becomes 0b0111, which is the same encoding as -1 (0b1111). 
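+// Worked example: a packed input byte 0x21 holds the nibbles {1, 2}; after unification it becomes 0x67, i.e. {0b0111, 0b0110}.
+// Negative values and zero already have the desired encoding and are left unchanged (e.g., -1 stays 0b1111).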
+bool unify_quant_encoding( + cutlass::DeviceAllocation const& block_in, + cutlass::DeviceAllocation& block_out) { + + using StorageType = cutlass::int4b_t::Storage; + + if (block_in.size() != block_out.size()) { + std::cerr << "block_in and block_out must have same size.\n"; + return false; + } + constexpr int pack = sizeof_bits_v / 4; + std::vector data(block_in.size() / pack); + cutlass::device_memory::copy_to_host(data.data(), (StorageType*)block_in.get(), block_in.size() / pack); + + for (auto&& d : data) { + StorageType out = 0; + StorageType mask = 0x0f; + for (int i = 0; i < pack; ++i) { + cutlass::int4b_t curr; + curr.storage = (d >> (i * 4)) & 0x0f; + switch (curr) { + case 1: curr.storage = StorageType(0b0111); break; // 2's complement + case 2: curr.storage = StorageType(0b0110); break; // 2's complement + case 3: curr.storage = StorageType(0b0101); break; // 2's complement + case 4: curr.storage = StorageType(0b0100); break; // 2's complement + case 5: curr.storage = StorageType(0b0011); break; // 2's complement + case 6: curr.storage = StorageType(0b0010); break; // 2's complement + case 7: curr.storage = StorageType(0b0001); break; // 2's complement + default: break; + } + out |= (curr.storage << (4 * i)) & mask; + mask <<= 4; + } + d = out; + } + + cutlass::device_memory::copy_to_device((uint8_t*)block_out.get(), data.data(), block_out.size() / 2); + return true; +} + +template +bool initialize_scale( + cutlass::DeviceAllocation& block, + Options const& options) { + + float elt_max_f = float(cutlass::platform::numeric_limits::max()); + float const max_dequant_val = 4.f; + float const min_dequant_val = 0.5f; + + float scope_max(max_dequant_val / elt_max_f); + float scope_min(min_dequant_val / elt_max_f); + + cutlass::reference::device::BlockFillRandomUniform( + block.get(), block.size(), seed, Element(scope_max), Element(scope_min)); + return true; +} + +bool initialize_packed_scale( + cutlass::DeviceAllocation const& block_in, + cutlass::DeviceAllocation > & block_out) { + + std::vector data_in(block_in.size()); + std::vector > data_out(block_in.size()); + try { + block_in.copy_to_host(data_in.data()); + } catch (cutlass::cuda_exception const& e) + { + std::cerr << "CUDA Error: " << cudaGetErrorString(e.cudaError()) << std::endl; + return false; + } + for (size_t i = 0; i < block_in.size(); ++i) + { + cutlass::packed_scale_t tmp(data_in[i]); + data_out[i] = reinterpret_cast const&>(tmp); + // std::cout << data_in[i] << ":" << std::hex << static_cast(data_in[i].storage) << ",\t" << -data_in[i] << ":" << std::hex << static_cast((-data_in[i]).storage) << std::endl; + } + try { + block_out.copy_from_host(data_out.data()); + } catch (cutlass::cuda_exception const& e) + { + std::cerr << "CUDA Error: " << cudaGetErrorString(e.cudaError()) << std::endl; + return false; + } + return true; +} + +template +bool initialize_zero( + cutlass::DeviceAllocation& block, + Options const& options) { + std::vector stage(block.size(), Element(0.0f)); + block.copy_from_host(stage.data()); + return true; +} + +/// Initialize operands to be used in the GEMM and reference GEMM +void initialize(Options const& options) { + + auto shape_b = cute::make_shape(options.n, options.k, options.l); + int const scale_k = (options.k + options.g - 1) / options.g; + stride_A = cutlass::make_cute_packed_stride(StrideA{}, cute::make_shape(options.m, options.k, options.l)); + stride_B = cutlass::make_cute_packed_stride(StrideB{}, shape_b); + // Reverse stride here due to swap and transpose + stride_C = 
cutlass::make_cute_packed_stride(StrideC{}, cute::make_shape(options.n, options.m, options.l)); + stride_C_ref = cutlass::make_cute_packed_stride(StrideC_ref{}, cute::make_shape(options.m, options.n, options.l)); + // Reverse stride here due to swap and transpose + stride_D = cutlass::make_cute_packed_stride(StrideD{}, cute::make_shape(options.n, options.m, options.l)); + stride_D_ref = cutlass::make_cute_packed_stride(StrideD_ref{}, cute::make_shape(options.m, options.n, options.l)); + + auto a_coord = cutlass::make_Coord(options.m * options.l, options.k); + auto b_coord = cutlass::make_Coord(options.k, options.n * options.l); + auto c_coord = cutlass::make_Coord(options.m * options.l, options.n); + + block_A.reset(a_coord.product()); + block_B.reset(b_coord.product()); + block_B_modified.reset(b_coord.product()); + block_B_dq.reset(b_coord.product()); + block_C.reset(c_coord.product()); + block_D.reset(c_coord.product()); + block_ref_D.reset(c_coord.product()); + + block_scale.reset(scale_k * options.l * options.n); + block_scale_packed.reset(scale_k * options.l * options.n); + block_zero.reset(scale_k * options.l * options.n); + + initialize_tensor(block_A, seed + 2022); + initialize_quant_tensor(block_B, seed + 2021); + unify_quant_encoding(block_B, block_B_modified); + initialize_tensor(block_C, seed + 2020); + initialize_scale(block_scale, options); + initialize_packed_scale(block_scale, block_scale_packed); + initialize_zero(block_zero, options); + + auto layout_B = make_layout(shape_b, stride_B); + + auto shape_scale_zero = cute::make_shape(options.n, scale_k, options.l); + stride_S = cutlass::make_cute_packed_stride(StrideS{}, cute::make_shape(options.n, scale_k, options.l)); + stride_S_ref = cutlass::make_cute_packed_stride(StrideS_ref{}, cute::make_shape(options.n, scale_k, options.l)); + auto layout_scale_zero = make_layout(shape_scale_zero, stride_S_ref); + + dequantize_weight(block_B_dq.get(), block_B.get(), layout_B, block_scale.get(), block_zero.get(), layout_scale_zero, options.g); +} + +/// Populates a Gemm::Arguments structure from the given commandline options +template +Args args_from_options(Options const& options) +{ +// Swap the A and B tensors, as well as problem shapes here. + + return Args { + cutlass::gemm::GemmUniversalMode::kGemm, + {options.n, options.m, options.k, options.l}, + {block_B_modified.get(), stride_B, block_A.get(), stride_A, block_scale_packed.get(), stride_S, options.g}, + {{options.alpha, options.beta}, block_C.get(), stride_C, block_D.get(), stride_D} + }; +} + +bool verify(Options const& options) { + // + // Compute reference output + // + + // In this example, we use the GPU default kernels as a reference (unfused scale). + // This avoids numerical differences due to different accumulation order. + + // Again, due to numerical differences, we must use fast acc here when the mma type is + // FP8 as the fused implementation only supports fast acc at the moment. 
+ constexpr bool IsFP8Input = cute::is_same_v || cute::is_same_v; + using FP8Sched = cute::conditional_t(TileShape{}) == 64, cutlass::gemm::KernelTmaWarpSpecializedPingpongFP8FastAccum, cutlass::gemm::KernelTmaWarpSpecializedCooperativeFP8FastAccum>; + using ScheduleRef = cute::conditional_t; + + using CollectiveMainloopRef = typename cutlass::gemm::collective::CollectiveBuilder< + ArchTag, OperatorClass, + MmaType, LayoutA, AlignmentA, + MmaType, LayoutB, AlignmentB, + ElementAccumulator, + TileShape, ClusterShape, + cutlass::gemm::collective::StageCountAuto, + ScheduleRef + >::CollectiveOp; + + using CollectiveEpilogueRef = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + TileShape, ClusterShape, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAccumulator, ElementAccumulator, + ElementC, LayoutC, AlignmentC, + ElementD, LayoutD, AlignmentD, + cutlass::epilogue::NoSmemWarpSpecialized + >::CollectiveOp; + + using GemmKernelRef = cutlass::gemm::kernel::GemmUniversal< + Shape, // Indicates ProblemShape + CollectiveMainloopRef, + CollectiveEpilogueRef + >; + + using GemmRef = cutlass::gemm::device::GemmUniversalAdapter; + + typename GemmRef::Arguments arguments{ + cutlass::gemm::GemmUniversalMode::kGemm, + {options.m, options.n, options.k, options.l}, + {block_A.get(), stride_A, block_B_dq.get(), stride_B}, + {{options.alpha, options.beta}, block_C.get(), stride_C_ref, block_ref_D.get(), stride_D_ref} + }; + + // Run the gemm where the scaling is performed outside of the kernel. + GemmRef gemm_ref; + size_t workspace_size = GemmRef::get_workspace_size(arguments); + cutlass::device_memory::allocation workspace(workspace_size); + CUTLASS_CHECK(gemm_ref.can_implement(arguments)); + CUTLASS_CHECK(gemm_ref.initialize(arguments, workspace.get())); + CUTLASS_CHECK(gemm_ref.run()); + + // compare_reference + ElementD const epsilon(1e-2f); + ElementD const non_zero_floor(1e-4f); + bool passed = cutlass::reference::device::BlockCompareRelativelyEqual(block_ref_D.get(), block_D.get(), block_D.size(), epsilon, non_zero_floor); + return passed; +} + +/// Execute a given example GEMM computation +template +int run(Options &options) +{ + initialize(options); + + // Instantiate CUTLASS kernel depending on templates + Gemm gemm; + + // Create a structure of gemm kernel arguments suitable for invoking an instance of Gemm + auto arguments = args_from_options(options); + + // Using the arguments, query for extra workspace required for matrix multiplication computation + size_t workspace_size = Gemm::get_workspace_size(arguments); + + // Allocate workspace memory + cutlass::device_memory::allocation workspace(workspace_size); + + // Check if the problem size is supported or not + CUTLASS_CHECK(gemm.can_implement(arguments)); + + // Initialize CUTLASS kernel with arguments and workspace pointer + CUTLASS_CHECK(gemm.initialize(arguments, workspace.get())); + + // Correctness / Warmup iteration + CUTLASS_CHECK(gemm.run()); + + // Check if output from CUTLASS kernel and reference kernel are equal or not + Result result; + result.passed = verify(options); + + std::cout << " Disposition: " << (result.passed ? "Passed" : "Failed") << std::endl; + + if (!result.passed) { + exit(-1); + } + + // Run profiling loop + if (options.iterations > 0) + { + GpuTimer timer; + timer.start(); + for (int iter = 0; iter < options.iterations; ++iter) { + CUTLASS_CHECK(gemm.run()); + } + timer.stop(); + + // Compute average runtime and GFLOPs. 
+ float elapsed_ms = timer.elapsed_millis(); + result.avg_runtime_ms = double(elapsed_ms) / double(options.iterations); + result.gflops = options.gflops(result.avg_runtime_ms / 1000.0); + + std::cout << " Problem Size: " << options.m << 'x' << options.n << 'x' << options.k << 'x' << options.l << std::endl; + std::cout << " Avg runtime: " << result.avg_runtime_ms << " ms" << std::endl; + std::cout << " GFLOPS: " << result.gflops << std::endl; + } + + return 0; +} + +#endif // defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +int main(int argc, char const **args) { + + // CUTLASS must be compiled with CUDA 12.0 Toolkit to run this example + // and must have compute capability at least 90. + if (__CUDACC_VER_MAJOR__ < 12) { + std::cerr << "This example requires CUDA 12 or newer.\n"; + // Returning zero so this test passes on older Toolkits. Its actions are no-op. + return 0; + } + + cudaDeviceProp props; + int current_device_id; + CUDA_CHECK(cudaGetDevice(¤t_device_id)); + CUDA_CHECK(cudaGetDeviceProperties(&props, current_device_id)); + cudaError_t error = cudaGetDeviceProperties(&props, 0); + if (props.major < 9) { + std::cerr + << "This example requires a GPU of NVIDIA's Hopper Architecture or " + << "later (compute capability 90 or greater).\n"; + return 0; + } + // + // Parse options + // + + Options options; + + options.parse(argc, args); + + if (options.help) { + options.print_usage(std::cout) << std::endl; + return 0; + } + + // + // Evaluate CUTLASS kernels + // + +#if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) + if (options.g == options.k) { + std::cout << "Running in per-column scale mode." << std::endl; + } else { + std::cout << "Running in group scale mode." << std::endl; + } + run(options); +#endif + + return 0; +} + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/examples/55_hopper_mixed_dtype_gemm/55_hopper_mixed_dtype_gemm.cu b/examples/55_hopper_mixed_dtype_gemm/55_hopper_mixed_dtype_gemm.cu index 28baae260c..8a99cc2754 100644 --- a/examples/55_hopper_mixed_dtype_gemm/55_hopper_mixed_dtype_gemm.cu +++ b/examples/55_hopper_mixed_dtype_gemm/55_hopper_mixed_dtype_gemm.cu @@ -53,14 +53,18 @@ equal to the gemm problem K. Limitations: - 1) Only supported combinations are 16-bit x {8-bit, 4-bit, 2-bit} and {8-bit} x {4-bit, 2-bit}. - 2) The narrow type must always be in K-major format. - 3) The scales and zeros must be MN major. That means if A is scaled, it must be column major, but if B is scaled it must be row major. - 4) The scales and the zeros must have the same layout and groupsize. + 1) The narrow type must always be in K-major format. + 2) The scales and zeros must be MN major. That means if A is scaled, it must be column major, but if B is scaled it must be row major. + 3) The scales and the zeros must have the same layout and groupsize. + 4) The groupsize must be greater or equal to tile shape k. 5) When dealing with 8-bit x {4-bit, 2-bit}, both inputs must be in K-major format. 6) Currently, TMA epilogues cannot be used when the narrow type is the B operand. This limitation arises because the implementation always swaps the operands to ensure that the narrow type passes through the register file, and TMA epilogues do not currently support implicit swap + transpose operations. We plan to address this limitation in the future. 
However, we address this in the example by explicitly swapping and transposing the operands. + + Optimizing suggestions: + 1) Use a small tile size, since register pressure for this GEMM (and for RS GEMMs in general) is high. + 2) Avoid the scale and zero modes when possible, since their extra computation becomes the bottleneck. Examples: @@ -94,11 +98,8 @@ #include "cutlass/util/host_tensor.h" #include "cutlass/util/packed_stride.hpp" #include "cutlass/util/tensor_view_io.h" -#include "cutlass/util/reference/host/tensor_fill.h" -#include "cutlass/util/reference/host/tensor_copy.h" -#include "cutlass/util/reference/host/tensor_compare.h" -#include "cutlass/util/reference/host/tensor_norm.h" -#include "cutlass/util/reference/host/gett.hpp" +#include "cutlass/util/reference/device/tensor_fill.h" +#include "cutlass/util/reference/device/tensor_compare.h" #include "helper.h" #include "unfused_weight_dequantize.hpp" @@ -117,8 +118,8 @@ enum GemmMode { ///////////////////////////////////////////////////////////////////////////////////////////////// /// GEMM kernel configurations ///////////////////////////////////////////////////////////////////////////////////////////////// -using MmaType = cutlass::float_e4m3_t; -using QuantType = cutlass::int4b_t; +using MmaType = cutlass::half_t; +using QuantType = cutlass::float_e4m3_t; constexpr int TileShapeK = 128 * 8 / sizeof_bits::value; // A matrix configuration @@ -154,8 +155,8 @@ using ElementAccumulator = float; // E using ElementCompute = float; // Element type for epilogue computation using ArchTag = cutlass::arch::Sm90; // Tag indicating the minimum SM that supports the intended feature using OperatorClass = cutlass::arch::OpClassTensorOp; // Operator class tag -using TileShape = Shape<_128,_256,cute::Int>; // Threadblock-level tile size -using ClusterShape = Shape<_2,_1,_1>; // Shape of the threadblocks in a cluster +using TileShape = Shape<_128,_128,cute::Int>; // Threadblock-level tile size +using ClusterShape = Shape<_1,_1,_1>; // Shape of the threadblocks in a cluster using KernelSchedule = cutlass::gemm::KernelTmaWarpSpecializedCooperativeMixedInput; // Kernel to launch based on the default setting in the Collective Builder using EpilogueSchedule = cutlass::epilogue::TmaWarpSpecializedCooperative; using EpilogueTileType = cutlass::epilogue::collective::EpilogueTileAuto; @@ -268,14 +269,14 @@ using StrideS_ref = cutlass::detail::TagToStrideB_t; StrideS stride_S; StrideS_ref stride_S_ref; -cutlass::HostTensor tensor_A; -cutlass::HostTensor tensor_B; -cutlass::HostTensor tensor_B_dq; -cutlass::HostTensor tensor_scale; -cutlass::HostTensor tensor_zero; -cutlass::HostTensor tensor_C; -cutlass::HostTensor tensor_D; -cutlass::HostTensor tensor_ref_D; +cutlass::DeviceAllocation block_A; +cutlass::DeviceAllocation block_B; +cutlass::DeviceAllocation block_B_dq; +cutlass::DeviceAllocation block_scale; +cutlass::DeviceAllocation block_zero; +cutlass::DeviceAllocation block_C; +cutlass::DeviceAllocation block_D; +cutlass::DeviceAllocation block_ref_D; #endif // defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) @@ -290,7 +291,7 @@ struct Options { float alpha = 1.0f; float beta = 0.0f; - int iterations = 1000; + int iterations = 10; int mode = 2; int m = 5120, n = 4096, k = 4096; int g = 128; @@ -368,9 +369,9 @@ struct Result ///////////////////////////////////////////////////////////////////////////////////////////////// /// Helper to initialize a block of device data -template +template bool initialize_tensor( - cutlass::TensorView view, +
cutlass::DeviceAllocation& block, uint64_t seed=2023) { double scope_max, scope_min; @@ -393,34 +394,35 @@ bool initialize_tensor( scope_max = 8; scope_min = -8; } - cutlass::reference::host::TensorFillRandomUniform( - view, seed, scope_max, scope_min); + cutlass::reference::device::BlockFillRandomUniform( + block.get(), block.size(), seed, Element(scope_max), Element(scope_min)); return true; } -template +template bool initialize_quant_tensor( - cutlass::TensorView view, + cutlass::DeviceAllocation& block, uint64_t seed=2023) { float scope_min = float(cutlass::platform::numeric_limits::lowest()); float scope_max = float(cutlass::platform::numeric_limits::max()); - cutlass::reference::host::TensorFillRandomUniform( - view, seed, scope_max, scope_min); + cutlass::reference::device::BlockFillRandomUniform( + block.get(), block.size(), seed, Element(scope_max), Element(scope_min)); return true; } -template +template bool initialize_scale( - cutlass::TensorView view, - const Options &options) { + cutlass::DeviceAllocation& block, + Options const& options) { if (options.mode == GemmMode::ConvertOnly) { // No scales, so just initialize with 1 so we can use the same kernel to dequantize the data. - cutlass::reference::host::TensorFill(view, Element(1.0f)); + std::vector stage(block.size(), Element(1.0f)); + block.copy_from_host(stage.data()); } else { float elt_max_f = float(cutlass::platform::numeric_limits::max()); @@ -430,32 +432,33 @@ bool initialize_scale( float scope_max(max_dequant_val / elt_max_f); float scope_min(min_dequant_val / elt_max_f); - cutlass::reference::host::TensorFillRandomUniform( - view, seed, scope_max, scope_min); + cutlass::reference::device::BlockFillRandomUniform( + block.get(), block.size(), seed, Element(scope_max), Element(scope_min)); } return true; } -template +template bool initialize_zero( - cutlass::TensorView view, - const Options &options) { + cutlass::DeviceAllocation& block, + Options const& options) { if (options.mode == GemmMode::ScaleWithZeroPoint) { - cutlass::reference::host::TensorFillRandomUniform( - view, seed, 2.0f, -2.0f); + cutlass::reference::device::BlockFillRandomUniform( + block.get(), block.size(), seed, Element(2.0f), Element(-2.0f)); } else { // No bias, so just initialize with 1 so we can use the same kernel to dequantize the data. 
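+    // The constant fill is staged in a host vector and copied into the device allocation; a zero-point of 0 makes dequantization reduce to a plain scale multiply.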
- cutlass::reference::host::TensorFill(view, Element(0.0f)); + std::vector stage(block.size(), Element(0.0f)); + block.copy_from_host(stage.data()); } return true; } /// Initialize operands to be used in the GEMM and reference GEMM -void initialize(const Options &options) { +void initialize(Options const& options) { auto shape_b = cute::make_shape(options.n, options.k, options.l); - const int scale_k = (options.k + options.g - 1) / options.g; + int const scale_k = (options.k + options.g - 1) / options.g; stride_A = cutlass::make_cute_packed_stride(StrideA{}, cute::make_shape(options.m, options.k, options.l)); stride_B = cutlass::make_cute_packed_stride(StrideB{}, shape_b); // Reverse stride here due to swap and transpose @@ -469,27 +472,21 @@ void initialize(const Options &options) { auto b_coord = cutlass::make_Coord(options.k, options.n * options.l); auto c_coord = cutlass::make_Coord(options.m * options.l, options.n); - tensor_A.resize(a_coord); - tensor_B.resize(b_coord); - tensor_B_dq.resize(b_coord); - tensor_C.resize(c_coord); - tensor_D.resize(c_coord); - tensor_ref_D.resize(c_coord); - - tensor_scale.resize({scale_k * options.l, options.n}); - tensor_zero.resize({scale_k * options.l, options.n}); + block_A.reset(a_coord.product()); + block_B.reset(b_coord.product()); + block_B_dq.reset(b_coord.product()); + block_C.reset(c_coord.product()); + block_D.reset(c_coord.product()); + block_ref_D.reset(c_coord.product()); - initialize_tensor(tensor_A.host_view(), seed + 2022); - initialize_quant_tensor(tensor_B.host_view(), seed + 2021); - initialize_tensor(tensor_C.host_view(), seed + 2020); - initialize_scale(tensor_scale.host_view(), options); - initialize_zero(tensor_zero.host_view(), options); + block_scale.reset(scale_k * options.l * options.n); + block_zero.reset(scale_k * options.l * options.n); - tensor_A.sync_device(); - tensor_B.sync_device(); - tensor_C.sync_device(); - tensor_scale.sync_device(); - tensor_zero.sync_device(); + initialize_tensor(block_A, seed + 2022); + initialize_quant_tensor(block_B, seed + 2021); + initialize_tensor(block_C, seed + 2020); + initialize_scale(block_scale, options); + initialize_zero(block_zero, options); auto layout_B = make_layout(shape_b, stride_B); @@ -498,37 +495,36 @@ void initialize(const Options &options) { stride_S_ref = cutlass::make_cute_packed_stride(StrideS_ref{}, cute::make_shape(options.n, scale_k, options.l)); auto layout_scale_zero = make_layout(shape_scale_zero, stride_S_ref); - dequantize_weight(tensor_B_dq.device_data(), tensor_B.device_data(), layout_B, tensor_scale.device_data(), tensor_zero.device_data(), layout_scale_zero, options.g); - tensor_B_dq.sync_host(); + dequantize_weight(block_B_dq.get(), block_B.get(), layout_B, block_scale.get(), block_zero.get(), layout_scale_zero, options.g); } /// Populates a Gemm::Arguments structure from the given commandline options template -Args args_from_options(const Options &options) +Args args_from_options(Options const& options) { // Swap the A and B tensors, as well as problem shapes here. 
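+  // The operands are swapped so that the narrow (quantized) type passes through the register file (see limitation 6 above): B and its stride are passed in the A position, and the problem shape is given as {n, m, k, l}.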
if (options.mode == GemmMode::ConvertOnly) { return Args { cutlass::gemm::GemmUniversalMode::kGemm, {options.n, options.m, options.k, options.l}, - {tensor_B.device_data(), stride_B, tensor_A.device_data(), stride_A}, - {{options.alpha, options.beta}, tensor_C.device_data(), stride_C, tensor_D.device_data(), stride_D} + {block_B.get(), stride_B, block_A.get(), stride_A}, + {{options.alpha, options.beta}, block_C.get(), stride_C, block_D.get(), stride_D} }; } else if (options.mode == GemmMode::ScaleOnly) { return Args { cutlass::gemm::GemmUniversalMode::kGemm, {options.n, options.m, options.k, options.l}, - {tensor_B.device_data(), stride_B, tensor_A.device_data(), stride_A, tensor_scale.device_data(), stride_S, options.g}, - {{options.alpha, options.beta}, tensor_C.device_data(), stride_C, tensor_D.device_data(), stride_D} + {block_B.get(), stride_B, block_A.get(), stride_A, block_scale.get(), stride_S, options.g}, + {{options.alpha, options.beta}, block_C.get(), stride_C, block_D.get(), stride_D} }; } else if (options.mode == GemmMode::ScaleWithZeroPoint) { return Args { cutlass::gemm::GemmUniversalMode::kGemm, {options.n, options.m, options.k, options.l}, - {tensor_B.device_data(), stride_B, tensor_A.device_data(), stride_A, tensor_scale.device_data(), stride_S, options.g, tensor_zero.device_data()}, - {{options.alpha, options.beta}, tensor_C.device_data(), stride_C, tensor_D.device_data(), stride_D} + {block_B.get(), stride_B, block_A.get(), stride_A, block_scale.get(), stride_S, options.g, block_zero.get()}, + {{options.alpha, options.beta}, block_C.get(), stride_C, block_D.get(), stride_D} }; } else { std::cerr << "Invalid mode " << options.mode << ". Must be 0, 1 or 2." << std::endl; @@ -542,7 +538,7 @@ bool verify(const Options &options) { // // In this example, we use the GPU default kernels as a reference (unfused scale) - // This is to avoid numerical differences from different accumulation order. + // This avoids numerical differences due to different accumulation order. // Again, due to numerical differences, we must use fast acc here when the mma type is // FP8 as the fused implementation only supports fast acc at the moment. @@ -581,8 +577,8 @@ bool verify(const Options &options) { typename GemmRef::Arguments arguments{ cutlass::gemm::GemmUniversalMode::kGemm, {options.m, options.n, options.k, options.l}, - {tensor_A.device_data(), stride_A, tensor_B_dq.device_data(), stride_B}, - {{options.alpha, options.beta}, tensor_C.device_data(), stride_C_ref, tensor_ref_D.device_data(), stride_D_ref} + {block_A.get(), stride_A, block_B_dq.get(), stride_B}, + {{options.alpha, options.beta}, block_C.get(), stride_C_ref, block_ref_D.get(), stride_D_ref} }; // Run the gemm where the scaling is performed outside of the kernel. 
@@ -594,11 +590,9 @@ bool verify(const Options &options) { CUTLASS_CHECK(gemm_ref.run()); // compare_reference - tensor_D.sync_host(); - tensor_ref_D.sync_host(); - const ElementD epsilon(1e-2f); - const ElementD non_zero_floor(1e-4f); - bool passed = cutlass::reference::host::TensorRelativelyEquals(tensor_ref_D.host_view(), tensor_D.host_view(), epsilon, non_zero_floor); + ElementD const epsilon(1e-2f); + ElementD const non_zero_floor(1e-4f); + bool passed = cutlass::reference::device::BlockCompareRelativelyEqual(block_ref_D.get(), block_D.get(), block_D.size(), epsilon, non_zero_floor); return passed; } diff --git a/examples/55_hopper_mixed_dtype_gemm/CMakeLists.txt b/examples/55_hopper_mixed_dtype_gemm/CMakeLists.txt index 5ddfbd2e6e..a9753ed100 100644 --- a/examples/55_hopper_mixed_dtype_gemm/CMakeLists.txt +++ b/examples/55_hopper_mixed_dtype_gemm/CMakeLists.txt @@ -55,5 +55,16 @@ cutlass_example_add_executable( TEST_SCALE_ZERO_GROUPED TEST_SCALE_RESIDUE TEST_SCALE_ZERO_RESIDUE - TEST_ALPHA_BETA + # TEST_ALPHA_BETA + ) + +cutlass_example_add_executable( + 55_hopper_int4_fp8_gemm + 55_hopper_int4_fp8_gemm.cu + TEST_COMMAND_OPTIONS + TEST_DIRECT_BATCHED + TEST_SCALE_PERCOL + TEST_SCALE_GROUP + TEST_SCALE_RESIDUE + # TEST_ALPHA_BETA ) diff --git a/examples/55_hopper_mixed_dtype_gemm/README.md b/examples/55_hopper_mixed_dtype_gemm/README.md index 8c393a6b75..07265f0d7e 100644 --- a/examples/55_hopper_mixed_dtype_gemm/README.md +++ b/examples/55_hopper_mixed_dtype_gemm/README.md @@ -11,6 +11,8 @@ This first version only supports mixed type GEMMs using TMA. While the example offers a harness for straightforward benchmarking, this initial implementation isn't optimized for performance in the majority of scenarios. We expect this implementation to be performant for `{fp16, bf16} x {int8, int4}` and `{fp8} x {int4}` for problems that are compute bound. Additionally, we expect good performance for `fp16, bf16` or `fp32` scales and zero-points. For best performance, it is ideal to have the scales and zero-points be the same type. +The scale only mode for `fp8 x int4` is significantly slower than direct conversion mode. There is a lookup-table workaround targeting this mode, as shown in `55_hopper_int4_fp8_gemm.cu`. To use this feature, use `cutlass::Array` as the scale type in the collective builder. However, it requires modifications to the encoding of quantized weights and scale factors. Also, scale with zero point mode is not supported for now. + We are currently optimizing the following cases: 1. Memory bound cases for all types diff --git a/examples/55_hopper_mixed_dtype_gemm/packed_scale.hpp b/examples/55_hopper_mixed_dtype_gemm/packed_scale.hpp new file mode 100644 index 0000000000..294d426135 --- /dev/null +++ b/examples/55_hopper_mixed_dtype_gemm/packed_scale.hpp @@ -0,0 +1,132 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. 
Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +#pragma once + +#include +#include + +#include "cutlass/float8.h" + +namespace cutlass +{ +template +class packed_scale_t { +public: + static_assert(cute::is_same_v || + cute::is_same_v || + cute::is_same_v || + cute::is_same_v, + "only 8 bit arithmetic types are supported."); + CUTLASS_HOST_DEVICE + explicit packed_scale_t(T val) { + if constexpr (!cute::is_unsigned_v) { + // Only pack negative values. The positive values are generated in flight in the mainloop. 
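+    // Each 32-bit word packs four 8-bit encodings, lowest byte first: storage[0] holds val * {-8, -7, -6, -5} and storage[1] holds val * {-4, -3, -2, -1}, so a quantized 4-bit value can be mapped to a pre-scaled result with a byte lookup.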
+ storage[0] = pack4(T(float(val) * -8.f), T(float(val) * -7.f), T(float(val) * -6.f), T(float(val) * -5.f)); + storage[1] = pack4(T(float(val) * -4.f), T(float(val) * -3.f), T(float(val) * -2.f), -val); + } + else { + storage[0] = pack4(T(float(val) * 8.f), T(float(val) * 7.f), T(float(val) * 6.f), T(float(val) * 5.f)); + storage[1] = pack4(T(float(val) * 4.f), T(float(val) * 3.f), T(float(val) * 2.f), val); + } + } + CUTLASS_HOST_DEVICE + packed_scale_t() = default; + CUTLASS_HOST_DEVICE + explicit operator float() const { + return float(get()); + } + CUTLASS_HOST_DEVICE + bool operator==(packed_scale_t const& rhs) const { + return storage[0] == rhs.storage[0] && storage[1] == rhs.storage[1]; + } + CUTLASS_HOST_DEVICE + bool operator!=(packed_scale_t const& rhs) const { + return !(*this == rhs); + } + CUTLASS_HOST_DEVICE + friend packed_scale_t operator+(packed_scale_t const& lhs, packed_scale_t const& rhs) { + return packed_scale_t(lhs.get() + rhs.get()); + } + CUTLASS_HOST_DEVICE + friend packed_scale_t operator-(packed_scale_t const& lhs, packed_scale_t const& rhs) { + return packed_scale_t(lhs.get() - rhs.get()); + } + CUTLASS_HOST_DEVICE + friend packed_scale_t operator*(packed_scale_t const& lhs, packed_scale_t const& rhs) { + return packed_scale_t(lhs.get() * rhs.get()); + } + CUTLASS_HOST_DEVICE + friend packed_scale_t operator/(packed_scale_t const& lhs, packed_scale_t const& rhs) { + return packed_scale_t(lhs.get() / rhs.get()); + } + +private: + using Storage = uint32_t; + using Stage = uint8_t; + + Storage storage[2] {}; + + CUTLASS_HOST_DEVICE + static Storage pack4(T c1, T c2, T c3, T c4) { + Storage result = 0; + result |= (static_cast(reinterpret_cast(c4)) << 24); + result |= (static_cast(reinterpret_cast(c3)) << 16); + result |= (static_cast(reinterpret_cast(c2)) << 8); + result |= static_cast(reinterpret_cast(c1)); + return result; + } + CUTLASS_HOST_DEVICE + T get() const { + auto stage = static_cast(storage[0] >> 8); + #if defined(__CUDA_ARCH__) + return reinterpret_cast(stage); + #else + T tmp; + std::memcpy(&tmp, &stage, sizeof(Stage)); + return tmp; + #endif + } + CUTLASS_HOST_DEVICE + T get(int idx) const { + Stage stage; + if (idx < 4) stage = static_cast(storage[0] >> (8 * idx)); + else stage = static_cast(storage[1] >> (8 * idx - 32)); + #if defined(__CUDA_ARCH__) + return reinterpret_cast(stage); + #else + T tmp; + std::memcpy(&tmp, &stage, sizeof(Stage)); + return tmp; + #endif + } +}; +} diff --git a/examples/56_hopper_ptr_array_batched_gemm/56_hopper_ptr_array_batched_gemm.cu b/examples/56_hopper_ptr_array_batched_gemm/56_hopper_ptr_array_batched_gemm.cu index 5181678ca7..51ce970dbd 100644 --- a/examples/56_hopper_ptr_array_batched_gemm/56_hopper_ptr_array_batched_gemm.cu +++ b/examples/56_hopper_ptr_array_batched_gemm/56_hopper_ptr_array_batched_gemm.cu @@ -32,7 +32,7 @@ /*! \file \brief Hopper Ptr-Array Batched GEMM example using CUTLASS 3 APIs for NVIDIA Hopper architecture. - This example demonstrates an implementation of Ptr-Array Batched GEMM using a TMA + GMMA + This example demonstrates an implementation of Ptr-Array Batched GEMM using a TMA + GMMA warp-specialized cooperative kernel. The new feature showcased in this example is on-the-fly modification of TMA descriptors to move between batches (represented by l). 
@@ -547,3 +547,4 @@ int main(int argc, char const **args) { } ///////////////////////////////////////////////////////////////////////////////////////////////// + diff --git a/examples/57_hopper_grouped_gemm/57_hopper_grouped_gemm.cu b/examples/57_hopper_grouped_gemm/57_hopper_grouped_gemm.cu index a26d904dcc..d57e1deea5 100644 --- a/examples/57_hopper_grouped_gemm/57_hopper_grouped_gemm.cu +++ b/examples/57_hopper_grouped_gemm/57_hopper_grouped_gemm.cu @@ -91,9 +91,9 @@ using namespace cute; using ProblemShape = cutlass::gemm::GroupProblemShape>; // per group -using ElementA = cutlass::float_e4m3_t; // Element type for A matrix operand -using ElementB = cutlass::float_e5m2_t; // Element type for B matrix operand -using ElementC = cutlass::half_t; // Element type for C and D matrix operands +using ElementA = cutlass::float_e4m3_t; // Element type for A matrix operand +using ElementB = cutlass::float_e5m2_t; // Element type for B matrix operand +using ElementC = cutlass::half_t; // Element type for C and D matrix operands #if defined(CUTLASS_ARCH_MMA_MODIFIABLE_TMA_SM90_SUPPORTED) diff --git a/examples/59_ampere_gather_scatter_conv/CMakeLists.txt b/examples/59_ampere_gather_scatter_conv/CMakeLists.txt index e7f164003d..ce22cd1f37 100644 --- a/examples/59_ampere_gather_scatter_conv/CMakeLists.txt +++ b/examples/59_ampere_gather_scatter_conv/CMakeLists.txt @@ -26,6 +26,8 @@ # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +if (NOT MSVC) + cutlass_example_add_executable( 59_ampere_gather_scatter_conv ampere_gather_scatter_conv.cu @@ -34,3 +36,5 @@ cutlass_example_add_executable( if (CUTLASS_ENABLE_OPENMP_TESTS AND OpenMP_CXX_FOUND) target_link_libraries(59_ampere_gather_scatter_conv PRIVATE OpenMP::OpenMP_CXX) endif() + +endif() diff --git a/examples/61_hopper_gemm_with_topk_and_softmax/61_hopper_gemm_with_topk_and_softmax.cu b/examples/61_hopper_gemm_with_topk_and_softmax/61_hopper_gemm_with_topk_and_softmax.cu new file mode 100644 index 0000000000..8bb14b4556 --- /dev/null +++ b/examples/61_hopper_gemm_with_topk_and_softmax/61_hopper_gemm_with_topk_and_softmax.cu @@ -0,0 +1,534 @@ +/*************************************************************************************************** + * Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! \file + \brief Hopper GEMM + Top-K + Softmax fusion + + This example illustrates how to use the LinCombTopKSoftmaxCol EVT node to fuse + Top-K and Softmax into the GEMM epilogue, with certain assumptions made. + + Those assumptions are as: + 1. Fusion is over the N dimension. + 2. Top-K is either 2 or 4 elements, and the value is static (meaning two kernels have to be + compiled to support both.) + 3. The GEMM tile shape along N is greater than or equal to problem size + along N. + + + The example runs the fused GEMM kernel, along with a standard unfused host reference, and + manually performs Top-K and softmax, and compares the error between tensors. + + Note that some numerical error (smaller than 1e-5) is to be expected, but this is true + in most efficient reduction kernels, because floating point addition is not necessarily + associative. +*/ + +#include + +#include "cutlass/cutlass.h" +#include "cutlass/numeric_types.h" + +#include "cute/tensor.hpp" +#include "cutlass/tensor_ref.h" +#include "cutlass/gemm/dispatch_policy.hpp" +#include "cutlass/gemm/collective/collective_builder.hpp" +#include "cutlass/gemm/device/gemm_universal_adapter.h" +#include "cutlass/gemm/kernel/gemm_universal.hpp" +#include "cutlass/epilogue/dispatch_policy.hpp" +#include "cutlass/epilogue/collective/collective_builder.hpp" + +#include "cutlass/util/command_line.h" +#include "cutlass/util/distribution.h" +#include "cutlass/util/host_tensor.h" +#include "cutlass/util/packed_stride.hpp" +#include "cutlass/util/tensor_view_io.h" +#include "cutlass/util/reference/host/error_metrics.h" +#include "cutlass/util/reference/host/tensor_fill.h" +#include "cutlass/util/reference/host/tensor_copy.h" +#include "cutlass/util/reference/host/tensor_compare.h" +#include "cutlass/util/reference/host/tensor_norm.h" +#include "cutlass/util/reference/host/gett.hpp" + + +#include "helper.h" + +using namespace cute; + +#if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) + +static constexpr int TopK = 2; +static constexpr bool EnableTopKSoftmax = TopK > 1; + +///////////////////////////////////////////////////////////////////////////////////////////////// +/// GEMM kernel configurations +///////////////////////////////////////////////////////////////////////////////////////////////// + +// A matrix configuration +using ElementA = cutlass::half_t; // Element type for A matrix operand +using LayoutA = cutlass::layout::RowMajor; // Layout type for A matrix operand +constexpr int AlignmentA = 128 / cutlass::sizeof_bits::value; // Memory access granularity/alignment of A matrix in units of elements (up to 16 bytes) + +// B matrix configuration +using ElementB = cutlass::half_t; // Element type for B matrix operand +using LayoutB = cutlass::layout::ColumnMajor; // Layout type for B matrix operand +constexpr int AlignmentB = 128 / cutlass::sizeof_bits::value; // Memory access granularity/alignment of B 
matrix in units of elements (up to 16 bytes) + +// C matrix configuration +using ElementC = void; +using LayoutC = cutlass::layout::RowMajor; +constexpr int AlignmentC = 1; + +// D matrix configuration +using ElementD = cutlass::half_t; // Element type for C and D matrix operands +using LayoutD = cutlass::layout::RowMajor; // Layout type for output +constexpr int AlignmentD = 128 / cutlass::sizeof_bits::value; // Memory access granularity/alignment of output in units of elements (up to 16 bytes) + +// Core kernel configurations +using ElementAccumulator = float; // Element type for internal accumulation +using ElementCompute = float; // Element type for epilogue computation +using ArchTag = cutlass::arch::Sm90; // Tag indicating the minimum SM that supports the intended feature +using OperatorClass = cutlass::arch::OpClassTensorOp; // Operator class tag +using TileShape = Shape<_64,_64,_128>; // Threadblock-level tile size +using ClusterShape = Shape<_1,_1,_1>; // Shape of the threadblocks in a cluster +using KernelSchedule = cutlass::gemm::KernelTmaWarpSpecialized; +using EpilogueSchedule = cutlass::epilogue::TmaWarpSpecialized; + +// Top-K + Softmax fusion operation +using FusionOperation = std::conditional_t, + typename cutlass::epilogue::fusion::LinearCombination +>; + +// The fusion op only allows for epilogue tiles matching the mainloop tile. +using EpilogueTileType = decltype(cute::take<0,2>(TileShape{})); + +using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + ArchTag, OperatorClass, + TileShape, ClusterShape, + EpilogueTileType, + ElementAccumulator, ElementCompute, + ElementC, LayoutC, AlignmentC, + ElementD, LayoutD, AlignmentD, + EpilogueSchedule, + FusionOperation + >::CollectiveOp; + +using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + ArchTag, OperatorClass, + ElementA, LayoutA, AlignmentA, + ElementB, LayoutB, AlignmentB, + ElementAccumulator, + TileShape, ClusterShape, + cutlass::gemm::collective::StageCountAutoCarveout< + static_cast(sizeof(typename CollectiveEpilogue::SharedStorage)) + >, + KernelSchedule + >::CollectiveOp; + +using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, // Indicates ProblemShape + CollectiveMainloop, + CollectiveEpilogue +>; + +using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + +// Extract information from Gemm kernel. 
+using EpilogueOutputOp = typename Gemm::EpilogueOutputOp; +using ElementScalar = typename EpilogueOutputOp::ElementScalar; + +using StrideA = typename Gemm::GemmKernel::StrideA; +using StrideB = typename Gemm::GemmKernel::StrideB; +using StrideD = typename Gemm::GemmKernel::StrideD; + +/// Initialization +StrideA stride_A; +StrideB stride_B; +StrideD stride_D; +uint64_t seed; + +cutlass::HostTensor tensor_A; +cutlass::HostTensor tensor_B; +cutlass::HostTensor tensor_D; +cutlass::HostTensor tensor_ref_D; + +using LayoutScalar = cutlass::layout::PackedVectorLayout; + +#endif // defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) + +///////////////////////////////////////////////////////////////////////////////////////////////// +/// Testbed utility types +///////////////////////////////////////////////////////////////////////////////////////////////// + +// Command line options parsing +struct Options { + + bool help = false; + + int iterations = 1000; + int m = 16, n = 8, k = 64, l = 1; + double eps = 1e-5; + + // Parses the command line + void parse(int argc, char const **args) { + cutlass::CommandLine cmd(argc, args); + + if (cmd.check_cmd_line_flag("help")) { + help = true; + return; + } + + cmd.get_cmd_line_argument("m", m); + cmd.get_cmd_line_argument("n", n); + cmd.get_cmd_line_argument("k", k); + cmd.get_cmd_line_argument("l", l); + cmd.get_cmd_line_argument("iterations", iterations); + cmd.get_cmd_line_argument("eps", eps); + } + + /// Prints the usage statement. + std::ostream & print_usage(std::ostream &out) const { + + out << "61_hopper_gemm_with_topk_and_softmax\n\n" + << " Hopper FP8 GEMM with Top-K and softmax fusion.\n\n" + << "Options:\n\n" + << " --help If specified, displays this usage statement\n\n" + << " --m= Sets the M extent of the GEMM\n" + << " --n= Sets the N extent of the GEMM\n" + << " --k= Sets the K extent of the GEMM\n" + << " --l= Sets the l extent (batch) of the GEMM\n" + << " --iterations= Number of profiling iterations to perform.\n\n" + << " --eps= Threshold of numerical verification. 
Default: 1e-5.\n\n"; + + out + << "\n\nExamples:\n\n" + << "$ " << "61_hopper_gemm_with_topk_and_softmax" << " --m=16 --n=8 --k=1024 \n\n"; + + return out; + } + + /// Compute performance in GFLOP/s + double gflops(double runtime_s) const + { + // Two flops per multiply-add + uint64_t flop = uint64_t(2) * m * n * k; + double gflop = double(flop) / double(1.0e9); + return gflop / runtime_s; + } + + float alpha() const { + return 1.f / static_cast(k); + } +}; + +/// Result structure +struct Result { + double avg_runtime_ms; + double gflops; + cutlass::Status status; + cudaError_t error; + bool passed; + + Result( + double avg_runtime_ms = 0, + double gflops = 0, + cutlass::Status status = cutlass::Status::kSuccess, + cudaError_t error = cudaSuccess) + : + avg_runtime_ms(avg_runtime_ms), gflops(gflops), status(status), error(error), passed(false) + {} + +}; + +#if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) + +///////////////////////////////////////////////////////////////////////////////////////////////// +/// GEMM setup and evaluation +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Helper to initialize a block of device data +template +bool initialize_tensor( + cutlass::TensorView view, + uint64_t seed) { + cutlass::reference::host::TensorFillRandomUniform( + view, seed, /* max = */ 1, /* min = */ -1, /* bits = */ 2); + return true; +} + +/// Initialize operands to be used in the GEMM and reference GEMM +void initialize(const Options &options) { + + stride_A = cutlass::make_cute_packed_stride(StrideA{}, cute::make_shape(options.m, options.k, options.l)); + stride_B = cutlass::make_cute_packed_stride(StrideB{}, cute::make_shape(options.n, options.k, options.l)); + stride_D = cutlass::make_cute_packed_stride(StrideD{}, cute::make_shape(options.m, options.n, options.l)); + + auto a_coord = cutlass::make_Coord(options.m * options.l, options.k); + auto c_coord = cutlass::make_Coord(options.m * options.l, options.n); + auto b_coord = cutlass::make_Coord(options.k, options.n * options.l); + + tensor_A.resize(a_coord); + tensor_B.resize(b_coord); + tensor_D.resize(c_coord); + tensor_ref_D.resize(c_coord); + + initialize_tensor(tensor_A.host_view(), seed + 2022); + initialize_tensor(tensor_B.host_view(), seed + 2023); + + tensor_A.sync_device(); + tensor_B.sync_device(); + tensor_D.sync_device(); +} + +/// Populates a Gemm::Arguments structure from the given commandline options +typename Gemm::Arguments args_from_options(const Options &options) { + typename Gemm::Arguments arguments{ + cutlass::gemm::GemmUniversalMode::kGemm, + {options.m, options.n, options.k, options.l}, + {tensor_A.device_data(), stride_A, tensor_B.device_data(), stride_B}, + { + {options.alpha(), 0.f}, // alpha, beta + nullptr, stride_D, + tensor_D.device_data(), stride_D + } + }; + + return arguments; +} + +bool verify(const Options &options) { + // + // Compute reference output + // + + // Create instantiation for device reference gemm kernel + auto A = cute::make_tensor(tensor_A.host_data(), + cute::make_layout(cute::make_shape(options.m, options.k, options.l), stride_A)); + auto B = cute::make_tensor(tensor_B.host_data(), + cute::make_layout(cute::make_shape(options.n, options.k, options.l), stride_B)); + auto D = cute::make_tensor(tensor_ref_D.host_data(), + cute::make_layout(cute::make_shape(options.m, options.n, options.l), stride_D)); + using unused_t = decltype(D); + + cutlass::reference::host::GettMainloopParams mainloop_params{A, B}; + + 
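+  // The host Gett reference first computes the plain GEMM with beta = 0; the same top-K + softmax transformation performed by the fused epilogue is then applied to the reference output below.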
cutlass::reference::host::GettEpilogueParams< + ElementScalar, + ElementScalar, + ElementAccumulator, + ElementCompute, + unused_t, + decltype(D), + unused_t, // bias + unused_t, // aux + unused_t, // valpha + unused_t // vbeta + > epilogue_params; + + epilogue_params.D = D; + epilogue_params.alpha = options.alpha(); + epilogue_params.beta = 0.f; + + // get reference result + cutlass::reference::host::Gemm3x(mainloop_params, epilogue_params); + + if constexpr (EnableTopKSoftmax) { + // top-K + softmax + for (int i = 0; i < options.m; ++i) { + + // Find Top-K + cutlass::Array top_k; + top_k.fill(-cutlass::platform::numeric_limits::infinity()); + for (int j = 0; j < options.n; ++j) { + auto val = static_cast(tensor_ref_D.host_view().ref().at({i, j})); + for (int top_k_idx = 0; top_k_idx < TopK; ++top_k_idx) { + if (val > top_k[top_k_idx]) { + // Shift down + for (int l = TopK - 1; l > top_k_idx; --l) { + top_k[l] = top_k[l - 1]; + } + top_k[top_k_idx] = val; + break; + } + } + } + + // This formulation of top-K + softmax only works when it is + // guaranteed that none of the top-K elements are repeated! + // If this is the case, the device kernel can also make mistakes, because + // A. Once the top-K values are reduced, and the operation is being applied, + // there is no way to tell repeated elements apart, so none are masked. + // B. The softmax sum of exps will be incorrect (because the repeated elements + // are not repeated in it.) + + ElementAccumulator max = top_k[0]; + ElementAccumulator sum = ElementAccumulator(0.f); + for (int top_k_idx = 0; top_k_idx < TopK; ++top_k_idx) { + sum = sum + cutlass::fast_exp(top_k[top_k_idx] - max); + } + + for (int j=0; j < options.n; ++j) { + auto val = tensor_ref_D.host_view().ref().at({i, j}); + if (val < top_k[TopK - 1]) { + tensor_ref_D.host_view().ref().at({i, j}) = static_cast(0.f); + } else { + // Softmax + auto softmax_val = cutlass::fast_exp(val - max) / sum; + tensor_ref_D.host_view().ref().at({i, j}) = static_cast(softmax_val); + } + } + } + } + + // compare_reference + tensor_D.sync_host(); + + double err = cutlass::reference::host::TensorRelativeErrorMetric( + tensor_D.host_view(), + tensor_ref_D.host_view()); + bool passed = err < options.eps; + + if (options.m <= 32 && options.n <= 32) { + std::cout << "GEMM output:\n" << tensor_D.host_view() << "\n\n"; + std::cout << "Reference output:\n" << tensor_ref_D.host_view() << "\n\n"; + } + + std::cout << " Disposition: " << (passed ? 
"Passed" : "Failed") << " \t Relative error: " << err << std::endl; + + return passed; +} + +/// Execute a given example GEMM computation +template +int run(Options &options) { + initialize(options); + + // Instantiate CUTLASS kernel depending on templates + Gemm gemm; + + // Create a structure of gemm kernel arguments suitable for invoking an instance of Gemm + auto arguments = args_from_options(options); + + // Using the arguments, query for extra workspace required for matrix multiplication computation + size_t workspace_size = Gemm::get_workspace_size(arguments); + + // Allocate workspace memory + cutlass::device_memory::allocation workspace(workspace_size); + + // Check if the problem size is supported or not + CUTLASS_CHECK(gemm.can_implement(arguments)); + + // Initialize CUTLASS kernel with arguments and workspace pointer + CUTLASS_CHECK(gemm.initialize(arguments, workspace.get())); + + // Correctness / Warmup iteration + CUTLASS_CHECK(gemm.run()); + + // Check if output from CUTLASS kernel and reference kernel are equal or not + Result result; + result.passed = verify(options); + + if (!result.passed) { + exit(-1); + } + + // Run profiling loop + if (options.iterations > 0) { + GpuTimer timer; + timer.start(); + for (int iter = 0; iter < options.iterations; ++iter) { + CUTLASS_CHECK(gemm.run()); + } + timer.stop(); + + // Compute average runtime and GFLOPs. + float elapsed_ms = timer.elapsed_millis(); + result.avg_runtime_ms = double(elapsed_ms) / double(options.iterations); + result.gflops = options.gflops(result.avg_runtime_ms / 1000.0); + + std::cout << " Problem Size: " << options.m << 'x' << options.n << 'x' << options.k << 'x' << options.l << std::endl; + std::cout << " Avg runtime: " << result.avg_runtime_ms << " ms" << std::endl; + std::cout << " GFLOPS: " << result.gflops << std::endl; + } + + return 0; +} + +#endif // defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +int main(int argc, char const **args) { + + // CUTLASS must be compiled with CUDA 12.0 Toolkit to run this example + // and must have compute capability at least 90. + if (__CUDACC_VER_MAJOR__ < 12) { + std::cerr << "This example requires CUDA 12 or newer.\n"; + // Returning zero so this test passes on older Toolkits. Its actions are no-op. + return 0; + } + + cudaDeviceProp props; + int current_device_id; + CUDA_CHECK(cudaGetDevice(¤t_device_id)); + CUDA_CHECK(cudaGetDeviceProperties(&props, current_device_id)); + cudaError_t error = cudaGetDeviceProperties(&props, 0); + if (props.major < 9) { + std::cerr + << "This example requires a GPU of NVIDIA's Hopper Architecture or " + << "later (compute capability 90 or greater).\n"; + return 0; + } + // + // Parse options + // + + Options options; + + options.parse(argc, args); + + if (options.help) { + options.print_usage(std::cout) << std::endl; + return 0; + } + + // + // Evaluate CUTLASS kernels + // + +#if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) + run(options); +#endif + + return 0; +} + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/examples/61_hopper_gemm_with_topk_and_softmax/CMakeLists.txt b/examples/61_hopper_gemm_with_topk_and_softmax/CMakeLists.txt new file mode 100644 index 0000000000..7d9160a733 --- /dev/null +++ b/examples/61_hopper_gemm_with_topk_and_softmax/CMakeLists.txt @@ -0,0 +1,32 @@ +# Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +cutlass_example_add_executable( + 61_hopper_gemm_with_topk_and_softmax + 61_hopper_gemm_with_topk_and_softmax.cu + ) diff --git a/examples/62_hopper_sparse_gemm/62_hopper_sparse_gemm.cu b/examples/62_hopper_sparse_gemm/62_hopper_sparse_gemm.cu new file mode 100644 index 0000000000..c3f1ce709a --- /dev/null +++ b/examples/62_hopper_sparse_gemm/62_hopper_sparse_gemm.cu @@ -0,0 +1,596 @@ +/*************************************************************************************************** + * Copyright (c) 2024 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! \file + \brief Hopper Sparse GEMM example. + + This example demonstrates how to construct and run a structured sparse GEMM kernel + on NVIDIA Hopper architecture. + +*/ + +#include + +#include "cutlass/cutlass.h" + +#include "cute/tensor.hpp" +#include "cutlass/tensor_ref.h" +#include "cutlass/epilogue/collective/default_epilogue.hpp" +#include "cutlass/epilogue/thread/linear_combination.h" +#include "cutlass/epilogue/collective/collective_builder.hpp" +#include "cutlass/gemm/dispatch_policy.hpp" +#include "cutlass/gemm/collective/collective_builder.hpp" +#include "cutlass/gemm/device/gemm_universal_adapter.h" +#include "cutlass/gemm/kernel/gemm_universal.hpp" +#include "cutlass/transform/device/transform_universal_adapter.hpp" +#include "cutlass/transform/kernel/sparse_gemm_compressor.hpp" + +#include "cutlass/util/command_line.h" +#include "cutlass/util/distribution.h" +#include "cutlass/util/host_tensor.h" +#include "cutlass/util/packed_stride.hpp" +#include "cutlass/util/tensor_view_io.h" +#include "cutlass/util/reference/device/gemm.h" +#include "cutlass/util/reference/device/tensor_compare.h" +#include "cutlass/util/reference/device/tensor_fill.h" + +#include "helper.h" + +using namespace cute; + +#if defined(CUTLASS_ARCH_MMA_SPARSE_SM90_SUPPORTED) + +///////////////////////////////////////////////////////////////////////////////////////////////// +/// GEMM kernel configurations +///////////////////////////////////////////////////////////////////////////////////////////////// + +// A matrix configuration +using ElementA = cutlass::half_t; // Element type for A matrix operand +using LayoutTagA = cutlass::layout::RowMajor; // Layout type for A matrix operand +constexpr int AlignmentA = 128 / cutlass::sizeof_bits::value; // Memory access granularity/alignment of A matrix in units of elements (up to 16 bytes) + +// B matrix configuration +using ElementB = cutlass::half_t; // Element type for B matrix operand +using LayoutTagB = cutlass::layout::ColumnMajor; // Layout type for B matrix operand +constexpr int AlignmentB = 128 / cutlass::sizeof_bits::value; // Memory access granularity/alignment of B matrix in units of elements (up to 16 bytes) + +// C/D matrix configuration +using ElementC = float; // Element type for C and D matrix operands +using LayoutTagC = cutlass::layout::ColumnMajor; // Layout type for C and D matrix operands +constexpr int AlignmentC = 128 / cutlass::sizeof_bits::value; // Memory access granularity/alignment of C matrix in units of elements (up to 16 bytes) + +// Core kernel configurations +using ElementAccumulator = float; // Element type for internal accumulation +using TileShape = Shape<_128,_128,_128>; // Threadblock-level tile size for sparse kernel +using TileShapeRef = Shape<_128,_128, _64>; // Threadblock-level tile size for reference (dense) kernel +using ClusterShape = Shape<_1,_2,_1>; // Shape of the threadblocks 
in a cluster +using KernelSchedule = cutlass::gemm::KernelTmaWarpSpecialized; // Kernel schedule policy +using EpilogueSchedule = cutlass::epilogue::TmaWarpSpecialized; // Epilogue schedule policy + +using ProblemShape = Shape; + +// Sparse kernel setup + +using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + TileShape, ClusterShape, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAccumulator, ElementAccumulator, + ElementC, LayoutTagC, AlignmentC, + ElementC, LayoutTagC, AlignmentC, + EpilogueSchedule + >::CollectiveOp; + +using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassSparseTensorOp, + ElementA, LayoutTagA, AlignmentA, + ElementB, LayoutTagB, AlignmentB, + ElementAccumulator, + TileShape, ClusterShape, + cutlass::gemm::collective::StageCountAutoCarveout< + static_cast(sizeof(typename CollectiveEpilogue::SharedStorage))>, + KernelSchedule + >::CollectiveOp; + +using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + ProblemShape, + CollectiveMainloop, + CollectiveEpilogue +>; + +using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + +// Reference (dense) kernel setup + +using CollectiveEpilogueRef = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + TileShapeRef, ClusterShape, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAccumulator, ElementAccumulator, + ElementC, LayoutTagC, AlignmentC, + ElementC, LayoutTagC, AlignmentC, + EpilogueSchedule + >::CollectiveOp; + +using CollectiveMainloopRef = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + ElementA, LayoutTagA, AlignmentA, + ElementB, LayoutTagB, AlignmentB, + ElementAccumulator, + TileShapeRef, ClusterShape, + cutlass::gemm::collective::StageCountAutoCarveout< + static_cast(sizeof(typename CollectiveEpilogue::SharedStorage))>, + KernelSchedule + >::CollectiveOp; + +using GemmKernelRef = cutlass::gemm::kernel::GemmUniversal< + ProblemShape, + CollectiveMainloopRef, + CollectiveEpilogue +>; + +using GemmRef = cutlass::gemm::device::GemmUniversalAdapter; + +// Layouts +using LayoutA = typename Gemm::GemmKernel::CollectiveMainloop::LayoutA; +using LayoutE = typename Gemm::GemmKernel::CollectiveMainloop::LayoutE; +using StrideB = typename Gemm::GemmKernel::StrideB; +using StrideC = typename Gemm::GemmKernel::StrideC; +using StrideD = typename Gemm::GemmKernel::StrideD; + +// Layouts for reference (non-sparse) tensors +using StrideA = cutlass::gemm::TagToStrideA_t; +using StrideE = StrideA; + +using ElementE = typename Gemm::GemmKernel::CollectiveMainloop::ElementE; +using SparseConfig = typename Gemm::GemmKernel::CollectiveMainloop::SparseConfig; + +// Offline compressor kernel +using CompressorUtility = cutlass::transform::kernel::StructuredSparseCompressorUtility< + ProblemShape, + ElementA, + LayoutTagA, + SparseConfig>; + +using CompressorKernel = cutlass::transform::kernel::StructuredSparseCompressor< + ProblemShape, + ElementA, + LayoutTagA, + SparseConfig, + cutlass::arch::Sm90>; + +using Compressor = cutlass::transform::device::TransformUniversalAdapter; + +// +// Data members +// + +ProblemShape problem_shape; + +StrideA stride_A; +StrideA stride_A_compressed; +StrideE stride_E; +StrideB stride_B; +StrideC stride_C; +StrideD stride_D; + +LayoutA layout_A; +LayoutE layout_E; + +uint64_t seed; + 
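+// block_A holds the dense A operand; block_A_compressed and block_E receive the compressor outputs (packed nonzero values and sparsity metadata); block_D_ref stores the output of the dense reference GEMM used for verification.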
+cutlass::DeviceAllocation block_A; +cutlass::DeviceAllocation block_A_compressed; +cutlass::DeviceAllocation block_E; +cutlass::DeviceAllocation block_B; +cutlass::DeviceAllocation block_C; +cutlass::DeviceAllocation block_D; +cutlass::DeviceAllocation block_D_ref; + +#endif // defined(CUTLASS_ARCH_MMA_SPARSE_SM90_SUPPORTED) + +///////////////////////////////////////////////////////////////////////////////////////////////// +/// Testbed utility types +///////////////////////////////////////////////////////////////////////////////////////////////// + +// Command line options parsing +struct Options { + + bool help; + + float alpha, beta; + int iterations; + int m, n, k, l; + + Options(): + help(false), + m(5120), n(4096), k(16384), l(1), + alpha(1.f), beta(0.f), + iterations(10) + { } + + // Parses the command line + void parse(int argc, char const **args) { + cutlass::CommandLine cmd(argc, args); + + if (cmd.check_cmd_line_flag("help")) { + help = true; + return; + } + + cmd.get_cmd_line_argument("m", m); + cmd.get_cmd_line_argument("n", n); + cmd.get_cmd_line_argument("k", k); + cmd.get_cmd_line_argument("l", l); + cmd.get_cmd_line_argument("alpha", alpha); + cmd.get_cmd_line_argument("beta", beta); + cmd.get_cmd_line_argument("iterations", iterations); + } + + /// Prints the usage statement. + std::ostream & print_usage(std::ostream &out) const { + + out << "62_hopper_sparse_gemm\n\n" + << " Hopper Sparse GEMM example.\n\n" + << "Options:\n\n" + << " --help If specified, displays this usage statement\n\n" + << " --m= Sets the M extent of the GEMM\n" + << " --n= Sets the N extent of the GEMM\n" + << " --k= Sets the K extent of the GEMM\n" + << " --l= Sets the L extent of the GEMM (batch size)\n" + << " --alpha= Epilogue scalar alpha\n" + << " --beta= Epilogue scalar beta\n\n" + << " --iterations= Number of profiling iterations to perform.\n\n"; + + out + << "\n\nExamples:\n\n" + << "$ " << "62_hopper_sparse_gemm" << " --m=4096 --n=5120 --k=8192 --l=1 --alpha=2 --beta=0.707 \n\n"; + + return out; + } + + /// Compute performance in GFLOP/s + double gflops(double runtime_s) const + { + // Two flops per multiply-add + uint64_t flop = uint64_t(2) * m * n * k; + double gflop = double(flop) / double(1.0e9); + return gflop / runtime_s; + } +}; + +#if defined(CUTLASS_ARCH_MMA_SPARSE_SM90_SUPPORTED) + +///////////////////////////////////////////////////////////////////////////////////////////////// +/// GEMM setup and evaluation +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Helper to initialize a block of device data +template +bool initialize_block( + cutlass::DeviceAllocation& block, + uint64_t seed) { + + Element scope_max, scope_min; + int bits_input = cutlass::sizeof_bits::value; + + if (bits_input == 1) { + scope_max = Element(2); + scope_min = Element(0); + } else if (bits_input <= 8) { + scope_max = Element(2); + scope_min = Element(-2); + } else { + scope_max = Element(8); + scope_min = Element(-8); + } + + cutlass::reference::device::BlockFillRandomUniform( + block.get(), block.size(), seed, scope_max, scope_min, 0); + + return true; +} + +/// Make A structured sparse by replacing elements with 0 and compress it +bool sparsify_and_compress() +{ + auto [M, N, K, L] = problem_shape; + CompressorUtility compressor_utility(problem_shape, stride_A); + + int ME = compressor_utility.get_metadata_m_physical(); + int KE = compressor_utility.get_metadata_k_physical(); + int KC = compressor_utility.get_tensorA_k_physical(); + + 
block_A_compressed.reset(M * KC * L); + block_E.reset(ME * KE * L); + + stride_A_compressed = cutlass::make_cute_packed_stride(StrideA{}, cute::make_shape(M, KC, L)); + stride_E = cutlass::make_cute_packed_stride(StrideE{}, cute::make_shape(ME, KE, L)); + + // Random sparsification is performed on host + std::vector block_A_host(block_A.size()); + cutlass::device_memory::copy_to_host(block_A_host.data(), block_A.get(), block_A.size()); + compressor_utility.structure_sparse_zero_mask_fill(block_A_host.data(), static_cast(seed + 2024)); + cutlass::device_memory::copy_to_device(block_A.get(), block_A_host.data(), block_A.size()); + + cutlass::KernelHardwareInfo hw_info; + hw_info.device_id = 0; + hw_info.sm_count = cutlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id); + typename Compressor::Arguments arguments { + problem_shape, + { block_A.get(), + stride_A, + block_A_compressed.get(), + block_E.get() }, + {hw_info} }; + + Compressor compressor_op; + size_t workspace_size = Compressor::get_workspace_size(arguments); + cutlass::device_memory::allocation workspace(workspace_size); + + CUTLASS_CHECK(compressor_op.can_implement(arguments)); + CUTLASS_CHECK(compressor_op.initialize(arguments, workspace.get())); + CUTLASS_CHECK(compressor_op.run()); + CUDA_CHECK(cudaDeviceSynchronize()); + + return true; +} + +/// Initialize operands to be used in the GEMM and reference GEMM +bool initialize(Options const& options) { + + problem_shape = make_tuple(options.m, options.n, options.k, options.l); + auto [M, N, K, L] = problem_shape; + + stride_A = cutlass::make_cute_packed_stride(StrideA{}, cute::make_shape(M, K, L)); + stride_B = cutlass::make_cute_packed_stride(StrideB{}, cute::make_shape(N, K, L)); + stride_C = cutlass::make_cute_packed_stride(StrideC{}, cute::make_shape(M, N, L)); + stride_D = cutlass::make_cute_packed_stride(StrideD{}, cute::make_shape(M, N, L)); + + // Allocate memory for tensors + block_A.reset(M * K * L); + block_B.reset(N * K * L); + block_C.reset(M * N * L); + block_D.reset(M * N * L); + block_D_ref.reset(M * N * L); + + // Fill input tensors with data + initialize_block(block_A, seed + 2021); + initialize_block(block_B, seed + 2022); + initialize_block(block_C, seed + 2023); + + // Replace 0 in A with 1 to avoid metadata changes + std::vector block_A_host(block_A.size()); + cutlass::device_memory::copy_to_host(block_A_host.data(), block_A.get(), block_A.size()); + for (size_t i = 0; i < block_A.size(); ++i) if (block_A_host[i] == ElementA(0)) block_A_host[i] = ElementA(1.0); + cutlass::device_memory::copy_to_device(block_A.get(), block_A_host.data(), block_A.size()); + + if (!sparsify_and_compress()) { + return false; + }; + + // Build the compressed/metadata layouts + layout_A = SparseConfig::fill_layoutA(problem_shape); + layout_E = SparseConfig::fill_layoutE(problem_shape); + + return true; +} + +/// Populates a Gemm::Arguments structure from the given commandline options +typename Gemm::Arguments make_args(Options const& options) +{ + typename Gemm::Arguments arguments{ + cutlass::gemm::GemmUniversalMode::kGemm, + problem_shape, + { block_A_compressed.get(), layout_A, block_B.get(), stride_B, block_E.get(), layout_E }, + { { ElementAccumulator(options.alpha), ElementAccumulator(options.beta) }, + block_C.get(), stride_C, block_D.get(), stride_D } + }; + + return arguments; +} + +typename GemmRef::Arguments make_args_ref(Options const& options) +{ + typename GemmRef::Arguments arguments{ + cutlass::gemm::GemmUniversalMode::kGemm, + 
problem_shape, + { block_A.get(), stride_A, block_B.get(), stride_B }, + { { ElementAccumulator(options.alpha), ElementAccumulator(options.beta) }, + block_C.get(), stride_C, block_D_ref.get(), stride_D } + }; + + return arguments; +} + +template +void print_device_tensor(cute::Tensor const& t) +{ + // Assumes size = cosize, i.e. compact tensor + std::vector data_host(t.size()); + cutlass::device_memory::copy_to_host(data_host.data(), t.data(), t.size()); + auto t_host = cute::make_tensor(data_host.data(), t.layout()); + cute::print_tensor(t_host); +} + +bool verify(Options const& options) { + CUDA_CHECK(cudaDeviceSynchronize()); + + bool passed = cutlass::reference::device::BlockCompareEqual(block_D_ref.get(), block_D.get(), block_D.size()); + +#if 0 + if (!passed) { + auto [M, N, K, L] = problem_shape; + CompressorUtility compressor_utility(problem_shape, stride_A); + int ME = compressor_utility.get_metadata_m_physical(); + int KE = compressor_utility.get_metadata_k_physical(); + int KC = compressor_utility.get_tensorA_k_physical(); + + cute::print("A (original): "); print_device_tensor(make_tensor(block_A.get(), make_shape(M, K, L), stride_A)); + cute::print("A (compressed): "); print_device_tensor(make_tensor(block_A_compressed.get(), make_shape(M, KC, L), stride_A_compressed)); + cute::print("E (physical): "); print_device_tensor(make_tensor(block_E.get(), make_shape(ME, KE, L), stride_E)); + cute::print("E (logical): "); print_device_tensor(make_tensor(block_E.get(), upcast(layout_E))); + cute::print("B: "); print_device_tensor(make_tensor(block_B.get(), make_shape(N, K, L), stride_B)); + cute::print("C: "); print_device_tensor(make_tensor(block_C.get(), make_shape(M, N, L), stride_C)); + cute::print("D reference: "); print_device_tensor(make_tensor(block_D_ref.get(), make_shape(M, N, L), stride_D)); + cute::print("D computed: "); print_device_tensor(make_tensor(block_D.get(), make_shape(M, N, L), stride_D)); + } +#endif + + return passed; +} + +template +struct Runner +{ + using Arguments = typename Gemm::Arguments; + + Runner(Arguments args): arguments(args) { + // Using the arguments, query for extra workspace required for matrix multiplication computation + size_t workspace_size = Gemm::get_workspace_size(arguments); + + // Allocate workspace memory + workspace.reset(workspace_size); + + // Check if the problem size is supported or not + CUTLASS_CHECK(gemm.can_implement(arguments)); + } + + void run() { + CUTLASS_CHECK(gemm.initialize(arguments, workspace.get())); + CUTLASS_CHECK(gemm.run()); + } + + void benchmark(Options const& options) { + if (options.iterations > 0) + { + GpuTimer timer; + timer.start(); + for (int iter = 0; iter < options.iterations; ++iter) { + run(); + } + timer.stop(); + + // Compute average runtime and GFLOPs. 
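+      // elapsed_millis() measures the whole timed loop, so divide by the
+      // iteration count; gflops() expects the average runtime in seconds.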
+ float elapsed_ms = timer.elapsed_millis(); + double avg_runtime_ms = double(elapsed_ms) / double(options.iterations); + double gflops = options.gflops(avg_runtime_ms / 1000.0); + + std::cout << " Avg runtime: " << avg_runtime_ms << " ms" << std::endl; + std::cout << " GFLOPS: " << gflops << std::endl; + } + } + + Gemm gemm; + Arguments arguments; + cutlass::device_memory::allocation workspace; +}; + +/// Execute the example (verification and timing) +void run(Options &options) { + bool init = initialize(options); + if (!init) { + std::cout << "Initialization failure" << std::endl; + exit(EXIT_FAILURE); + } + + Runner gemm(make_args(options)); + Runner gemm_ref(make_args_ref(options)); + + gemm.run(); + gemm_ref.run(); + + bool passed = verify(options); + + std::cout << " Problem Size: " << options.m << 'x' << options.n << 'x' << options.k << std::endl; + std::cout << " Disposition: " << (passed ? "Passed" : "Failed") << std::endl; + + if (!passed) { + exit(EXIT_FAILURE); + } + + std::cout << "Sparse GEMM:" << std::endl; + gemm.benchmark(options); + + std::cout << "Dense GEMM:" << std::endl; + gemm_ref.benchmark(options); +} + +#endif // defined(CUTLASS_ARCH_MMA_SPARSE_SM90_SUPPORTED) + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +int main(int argc, char const **args) { + + // CUTLASS must be compiled with CUDA 12.2 Toolkit to run this example + // and must have compute capability at least 90. + if (__CUDACC_VER_MAJOR__ < 12 || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ < 2)) { + std::cerr << "This example requires CUDA 12.2 or newer.\n"; + // Returning zero so this test passes on older Toolkits. Its actions are no-op. + return 0; + } + + cudaDeviceProp props; + int current_device_id; + CUDA_CHECK(cudaGetDevice(¤t_device_id)); + CUDA_CHECK(cudaGetDeviceProperties(&props, current_device_id)); + cudaError_t error = cudaGetDeviceProperties(&props, 0); + if (props.major < 9) { + std::cerr + << "This example requires a GPU of NVIDIA's Hopper Architecture or " + << "later (compute capability 90 or greater).\n"; + return 0; + } + // + // Parse options + // + + Options options; + + options.parse(argc, args); + + if (options.help) { + options.print_usage(std::cout) << std::endl; + return 0; + } + + // + // Evaluate CUTLASS kernels + // + +#if defined(CUTLASS_ARCH_MMA_SPARSE_SM90_SUPPORTED) + run(options); +#endif + + return EXIT_SUCCESS; +} + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/examples/62_hopper_sparse_gemm/CMakeLists.txt b/examples/62_hopper_sparse_gemm/CMakeLists.txt new file mode 100644 index 0000000000..cf55da4552 --- /dev/null +++ b/examples/62_hopper_sparse_gemm/CMakeLists.txt @@ -0,0 +1,36 @@ + +# Copyright (c) 2024 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. 
Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +# Sparse kernel in this example triggers an ICE in gcc 7.5 +if (NOT (CMAKE_CXX_COMPILER_ID STREQUAL "GNU" AND CMAKE_CXX_COMPILER_VERSION VERSION_LESS 8.0)) +cutlass_example_add_executable( + 62_hopper_sparse_gemm + 62_hopper_sparse_gemm.cu + ) +endif() diff --git a/examples/63_hopper_gemm_with_weight_prefetch/63_hopper_gemm_with_weight_prefetch.cu b/examples/63_hopper_gemm_with_weight_prefetch/63_hopper_gemm_with_weight_prefetch.cu new file mode 100644 index 0000000000..03c54a8ee9 --- /dev/null +++ b/examples/63_hopper_gemm_with_weight_prefetch/63_hopper_gemm_with_weight_prefetch.cu @@ -0,0 +1,500 @@ +/*************************************************************************************************** + * Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! 
\file + \brief Hopper FP8 GEMM + L2 Weight Prefetch + + This example implements a non-persistent warp-specialized GEMM kernel for the Hopper + architecture with programmatic dependent launch (PDL) enabling prefetching weights into + L2 cache. + + For more information about dependent launch refer to the CUDA programming guide: + https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#programmatic-dependent-launch-and-synchronization + + In some cases, PDL can result in a window where a previous kernel is not actively utilizing + DRAM, and the next kernel sits idle until the previous finishes. During this window, the next + kernel can begin loading a non-dependent operand (i.e. weights in a linear projection are + typically static) and cache it in L2. + + The kernel and collective mainloop assume operand `A` corresponds to weights and operand `B` + corresponds to activations (so we can have very small batch/token count). + After initialization, the prefetch warp starts loading K tiles of `A` into an unused portion + of shared memory, and loads up to half of all K tiles that the same CTA would eventually load. + The exact number of K tiles loaded is determined by `args.mainloop.prefetch_ratio` \in + [0.0, 1.0]. Smaller values result in less prefetching, and larger values result in more. + Negative values result in a "best-effort" prefetch, meaning prefetcher will stop issuing weight + loads as soon as the activation DMA warp starts loading (as soon as it is signaled that the + previous kernel has flushed its memory.) + + The DMA warp responsible for loading `A` will also begin loading K tiles until it fills up + the available shared memory. + The DMA warp responsible for loading `B` will wait until activations are flushed to global + memory by the preceding kernel. + + Another mainloop parameter, `args.mainloop.overlap_ratio` \in [0.0, 1.0] determines how early + the next kernel (the one doing the prefetch) is launched. Smaller values result in greater + overlap, and larger values result in smaller overlap. Negative values disable PDL completely, + meaning there will be no overlap. This will make prefetch ineffective. + + These two runtime parameters should be tuned per problem size and GEMM config combination, and + if feasible, per-operation in an entire layer or model. + + NOTE: you must build this target with the following flag to enable Grid Dependency Control + instructions (GDC) in CUTLASS: + - CUTLASS_ENABLE_GDC_FOR_SM90 + + To lock persistence mode, power (350W), clocks (1005MHz) for evaluation (assumes device 0 and H100) + + $ sudo nvidia-smi -pm 1 -i 0 + + $ sudo nvidia-smi -i 0 -pl 350 + + $ sudo nvidia-smi -i 0 -lgc 1005 + + Example: + + $ mkdir build && cd build + + $ cmake .. 
-DCUTLASS_NVCC_ARCHS="90a" -DCUTLASS_ENABLE_GDC_FOR_SM90=1 + + $ cd examples/63_hopper_gemm_with_weight_prefetch + + $ make + + $ ./63_hopper_gemm_with_weight_prefetch --p=0.5 --o=0.5 +*/ + +#include + +#include "cutlass/cutlass.h" +#include "cutlass/numeric_types.h" + +#include "cute/tensor.hpp" +#include "cutlass/tensor_ref.h" +#include "cutlass/gemm/dispatch_policy.hpp" +#include "cutlass/gemm/collective/collective_builder.hpp" +#include "cutlass/gemm/device/gemm_universal_adapter.h" +#include "cutlass/gemm/kernel/gemm_universal.hpp" +#include "cutlass/epilogue/dispatch_policy.hpp" +#include "cutlass/epilogue/collective/collective_builder.hpp" + +#include "cutlass/util/command_line.h" +#include "cutlass/util/distribution.h" +#include "cutlass/util/host_tensor.h" +#include "cutlass/util/packed_stride.hpp" +#include "cutlass/util/tensor_view_io.h" +#include "cutlass/util/reference/host/tensor_fill.h" +#include "cutlass/util/reference/host/tensor_copy.h" +#include "cutlass/util/reference/host/tensor_compare.h" +#include "cutlass/util/reference/host/tensor_norm.h" +#include "cutlass/util/reference/host/gett.hpp" + + +#include "collective/dispatch_policy_extra.hpp" +#include "collective/builder.hpp" +#include "kernel/sm90_gemm_tma_warpspecialized_with_prefetch.hpp" + +#include "helper.h" +#include "gemm_with_weight_prefetch_commandline.hpp" + +using namespace cute; + +#if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) + +///////////////////////////////////////////////////////////////////////////////////////////////// +/// GEMM kernel configurations +///////////////////////////////////////////////////////////////////////////////////////////////// + +// A matrix configuration +using ElementA = cutlass::float_e4m3_t; // Element type for A matrix operand +using LayoutA = cutlass::layout::RowMajor; // Layout type for A matrix operand +constexpr int AlignmentA = 128 / cutlass::sizeof_bits::value; // Memory access granularity/alignment of A matrix in units of elements (up to 16 bytes) + +// B matrix configuration +using ElementB = cutlass::float_e5m2_t; // Element type for B matrix operand +using LayoutB = cutlass::layout::ColumnMajor; // Layout type for B matrix operand +constexpr int AlignmentB = 128 / cutlass::sizeof_bits::value; // Memory access granularity/alignment of B matrix in units of elements (up to 16 bytes) + +// C matrix configuration +using ElementC = cutlass::float_e4m3_t; // Element type for C and D matrix operands +using LayoutC = cutlass::layout::ColumnMajor; // Layout type for C and D matrix operands +constexpr int AlignmentC = 128 / cutlass::sizeof_bits::value; // Memory access granularity/alignment of C matrix in units of elements (up to 16 bytes) + +// D matrix configuration +using ElementD = ElementC; +using LayoutD = LayoutC; +constexpr int AlignmentD = AlignmentC; + +// Core kernel configurations +using ElementAccumulator = float; // Element type for internal accumulation +using ElementCompute = float; // Element type for epilogue computation +using ArchTag = cutlass::arch::Sm90; // Tag indicating the minimum SM that supports the intended feature +using OperatorClass = cutlass::arch::OpClassTensorOp; // Operator class tag +using TileShape = Shape<_64,_64,_128>; // Threadblock-level tile size +// Cluster_N > 1 is not supported yet. 
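+// (The example therefore uses a 1x1x1 cluster.)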
+using ClusterShape = Shape<_1,_1,_1>; // Shape of the threadblocks in a cluster +using KernelSchedule = cutlass::gemm::KernelTmaWarpSpecializedFP8FastAccumWithPrefetchAndSplitDMA; +using EpilogueSchedule = cutlass::epilogue::TmaWarpSpecialized; +using EpilogueTileType = cutlass::epilogue::collective::EpilogueTileAuto; + +using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + ArchTag, OperatorClass, + TileShape, ClusterShape, + EpilogueTileType, + ElementAccumulator, ElementCompute, + ElementC, LayoutC, AlignmentC, + ElementD, LayoutD, AlignmentD, + EpilogueSchedule + >::CollectiveOp; + +using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + ArchTag, OperatorClass, + ElementA, LayoutA, AlignmentA, + ElementB, LayoutB, AlignmentB, + ElementAccumulator, + TileShape, ClusterShape, + cutlass::gemm::collective::StageCountAutoCarveout< + static_cast(sizeof(typename CollectiveEpilogue::SharedStorage)) + >, + KernelSchedule + >::CollectiveOp; + +using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, // Indicates ProblemShape + CollectiveMainloop, + CollectiveEpilogue +>; + +using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + +// Extract information from Gemm kernel. +using EpilogueOutputOp = typename Gemm::EpilogueOutputOp; +using ElementScalar = typename EpilogueOutputOp::ElementScalar; + +using StrideA = typename Gemm::GemmKernel::StrideA; +using StrideB = typename Gemm::GemmKernel::StrideB; +using StrideC = typename Gemm::GemmKernel::StrideC; +using StrideD = typename Gemm::GemmKernel::StrideD; + +/// Initialization +StrideA stride_A; +StrideB stride_B; +StrideC stride_C; +StrideD stride_D; +uint64_t seed; + +cutlass::HostTensor tensor_A; +cutlass::HostTensor tensor_B; +cutlass::HostTensor tensor_C; +cutlass::HostTensor tensor_D; +cutlass::HostTensor tensor_ref_D; + +using LayoutScalar = cutlass::layout::PackedVectorLayout; +cutlass::HostTensor scalar_alpha; +cutlass::HostTensor scalar_beta; + +#endif // defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) + +///////////////////////////////////////////////////////////////////////////////////////////////// +/// Testbed utility types +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Result structure +struct Result +{ + double avg_runtime_ms; + double gflops; + double eff_bw; + cutlass::Status status; + cudaError_t error; + bool passed; + + Result( + double avg_runtime_ms = 0, + double gflops = 0, + double eff_bw = 0, + cutlass::Status status = cutlass::Status::kSuccess, + cudaError_t error = cudaSuccess) + : + avg_runtime_ms(avg_runtime_ms), gflops(gflops), eff_bw(eff_bw), status(status), error(error), passed(false) + {} + +}; + +#if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) + +///////////////////////////////////////////////////////////////////////////////////////////////// +/// GEMM setup and evaluation +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Helper to initialize a block of device data +template +bool initialize_tensor( + cutlass::TensorView view, + uint64_t seed) { + + double scope_max, scope_min; + int bits_input = cutlass::sizeof_bits::value; + int bits_output = cutlass::sizeof_bits::value; + + if (bits_input == 1) { + scope_max = 2; + scope_min = 0; + } + else if (bits_input <= 8) { + scope_max = 2; + scope_min = -2; + } + else if (bits_output == 16) { + scope_max = 5; + scope_min = -5; + } + else { + scope_max = 8; + scope_min = -8; + } + 
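+  // Fill the host view with uniformly distributed random values in
+  // [scope_min, scope_max]; the narrower ranges chosen above for low-precision
+  // types keep products and sums representable without saturating.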
cutlass::reference::host::TensorFillRandomUniform( + view, seed, scope_max, scope_min, 0); + + return true; +} + +/// Initialize operands to be used in the GEMM and reference GEMM +void initialize(const Options &options) { + + stride_A = cutlass::make_cute_packed_stride(StrideA{}, cute::make_shape(options.m, options.k, options.l)); + stride_B = cutlass::make_cute_packed_stride(StrideB{}, cute::make_shape(options.n, options.k, options.l)); + stride_C = cutlass::make_cute_packed_stride(StrideC{}, cute::make_shape(options.m, options.n, options.l)); + stride_D = cutlass::make_cute_packed_stride(StrideD{}, cute::make_shape(options.m, options.n, options.l)); + + auto a_coord = cutlass::make_Coord(options.m * options.l, options.k); + auto c_coord = cutlass::make_Coord(options.m * options.l, options.n); + auto b_coord = cutlass::make_Coord(options.k, options.n * options.l); + + tensor_A.resize(a_coord); + tensor_B.resize(b_coord); + tensor_C.resize(c_coord); + tensor_D.resize(c_coord); + tensor_ref_D.resize(c_coord); + + initialize_tensor(tensor_A.host_view(), seed + 2022); + initialize_tensor(tensor_B.host_view(), seed + 2023); + initialize_tensor(tensor_C.host_view(), seed + 2024); + + tensor_A.sync_device(); + tensor_B.sync_device(); + tensor_C.sync_device(); + tensor_D.sync_device(); +} + +/// Populates a Gemm::Arguments structure from the given commandline options +typename Gemm::Arguments args_from_options(const Options &options) +{ + typename Gemm::Arguments arguments{ + cutlass::gemm::GemmUniversalMode::kGemm, + {options.m, options.n, options.k, options.l}, + {tensor_A.device_data(), stride_A, tensor_B.device_data(), stride_B}, + { + {}, // epilogue.thread + tensor_C.device_data(), stride_C, + tensor_D.device_data(), stride_D + } + }; + + auto &fusion_args = arguments.epilogue.thread; + fusion_args.alpha = options.alpha; + fusion_args.beta = options.beta; + fusion_args.alpha_ptr = scalar_alpha.device_data(); + fusion_args.beta_ptr = scalar_beta.device_data(); + + arguments.mainloop.overlap_ratio = options.overlap_ratio; + arguments.mainloop.prefetch_ratio = options.prefetch_ratio; + + return arguments; +} + +bool verify(const Options &options) { + // + // Compute reference output + // + + // Create instantiation for device reference gemm kernel + auto A = cute::make_tensor(tensor_A.host_data(), + cute::make_layout(cute::make_shape(options.m, options.k, options.l), stride_A)); + auto B = cute::make_tensor(tensor_B.host_data(), + cute::make_layout(cute::make_shape(options.n, options.k, options.l), stride_B)); + auto C = cute::make_tensor(tensor_C.host_data(), + cute::make_layout(cute::make_shape(options.m, options.n, options.l), stride_C)); + auto D = cute::make_tensor(tensor_ref_D.host_data(), + cute::make_layout(cute::make_shape(options.m, options.n, options.l), stride_D)); + using unused_t = decltype(D); + + cutlass::reference::host::GettMainloopParams mainloop_params{A, B}; + + cutlass::reference::host::GettEpilogueParams< + ElementScalar, + ElementScalar, + ElementAccumulator, + ElementCompute, + decltype(C), + decltype(D), + unused_t, // bias + unused_t, // aux + unused_t, // valpha + unused_t // vbeta + > epilogue_params; + + epilogue_params.C = C; + epilogue_params.D = D; + epilogue_params.alpha = options.alpha; + epilogue_params.beta = options.beta; + + // get reference result + cutlass::reference::host::Gemm3x(mainloop_params, epilogue_params); + + // compare_reference + tensor_D.sync_host(); + bool passed = cutlass::reference::host::TensorEquals(tensor_ref_D.host_view(), 
tensor_D.host_view()); + + return passed; +} + +/// Execute a given example GEMM computation +template +int run(Options &options) +{ + initialize(options); + + // Instantiate CUTLASS kernel depending on templates + Gemm gemm; + + // Create a structure of gemm kernel arguments suitable for invoking an instance of Gemm + auto arguments = args_from_options(options); + + // Using the arguments, query for extra workspace required for matrix multiplication computation + size_t workspace_size = Gemm::get_workspace_size(arguments); + + // Allocate workspace memory + cutlass::device_memory::allocation workspace(workspace_size); + + // Check if the problem size is supported or not + CUTLASS_CHECK(gemm.can_implement(arguments)); + + // Initialize CUTLASS kernel with arguments and workspace pointer + CUTLASS_CHECK(gemm.initialize(arguments, workspace.get())); + + // Correctness / Warmup iteration + CUTLASS_CHECK(gemm.run(nullptr, nullptr, /* launch_with_pdl = */ options.overlap_ratio >= 0)); + + // Check if output from CUTLASS kernel and reference kernel are equal or not + Result result; + result.passed = verify(options); + + std::cout << " Disposition: " << (result.passed ? "Passed" : "Failed") << std::endl; + + if (!result.passed) { + exit(-1); + } + + // Run profiling loop + if (options.iterations > 0) + { + GpuTimer timer; + timer.start(); + for (int iter = 0; iter < options.iterations; ++iter) { + CUTLASS_CHECK(gemm.run(nullptr, nullptr, /* launch_with_pdl = */ options.overlap_ratio >= 0)); + } + timer.stop(); + + // Compute average runtime and GFLOPs. + float elapsed_ms = timer.elapsed_millis(); + result.avg_runtime_ms = double(elapsed_ms) / double(options.iterations); + double avg_runtime_s = (double)(result.avg_runtime_ms / 1000.0); + result.gflops = options.gflops(avg_runtime_s); + result.eff_bw = options.effective_bandwidth(avg_runtime_s, sizeof(ElementA), sizeof(ElementB), sizeof(ElementC), sizeof(ElementD)); + + std::cout << " Problem Size: " << options.m << 'x' << options.n << 'x' << options.k << 'x' << options.l << std::endl; + std::cout << " Avg runtime: " << result.avg_runtime_ms << " ms" << std::endl; + std::cout << " GFLOPS: " << result.gflops << std::endl; + std::cout << " Effective bandwidth: " << result.eff_bw << " GB/s" << std::endl; + } + + return 0; +} + +#endif // defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +int main(int argc, char const **args) { + + // CUTLASS must be compiled with CUDA 12.0 Toolkit to run this example + // and must have compute capability at least 90. + if (__CUDACC_VER_MAJOR__ < 12) { + std::cerr << "This example requires CUDA 12 or newer.\n"; + // Returning zero so this test passes on older Toolkits. Its actions are no-op. 
+ return 0; + } + + cudaDeviceProp props; + int current_device_id; + CUDA_CHECK(cudaGetDevice(¤t_device_id)); + CUDA_CHECK(cudaGetDeviceProperties(&props, current_device_id)); + cudaError_t error = cudaGetDeviceProperties(&props, 0); + if (props.major < 9) { + std::cerr + << "This example requires a GPU of NVIDIA's Hopper Architecture or " + << "later (compute capability 90 or greater).\n"; + return 0; + } + // + // Parse options + // + + Options options; + + options.parse(argc, args); + + if (options.help) { + options.print_usage(std::cout) << std::endl; + return 0; + } + + // + // Evaluate CUTLASS kernels + // + +#if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) + run(options); +#endif + + return 0; +} diff --git a/examples/63_hopper_gemm_with_weight_prefetch/CMakeLists.txt b/examples/63_hopper_gemm_with_weight_prefetch/CMakeLists.txt new file mode 100644 index 0000000000..f48673241a --- /dev/null +++ b/examples/63_hopper_gemm_with_weight_prefetch/CMakeLists.txt @@ -0,0 +1,36 @@ +# Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +include_directories( + . +) + +cutlass_example_add_executable( + 63_hopper_gemm_with_weight_prefetch + 63_hopper_gemm_with_weight_prefetch.cu + ) diff --git a/examples/63_hopper_gemm_with_weight_prefetch/README.md b/examples/63_hopper_gemm_with_weight_prefetch/README.md new file mode 100644 index 0000000000..5dac1cc6c2 --- /dev/null +++ b/examples/63_hopper_gemm_with_weight_prefetch/README.md @@ -0,0 +1,82 @@ +# GEMM with L2 weight prefetch + +A non-persistent warp specialized GEMM directed at low latency inference. + +The kernel can optionally prefetch a portion of weights (operand `A`) into L2 cache while the +rest of the warps are waiting on the previous kernel to finish writing and flush its memory. +An example of this is normalization or reduction kernels that are immediately followed by a GEMM. + +It exposes two runtime parameters: +1. `overlap_ratio`: how early `griddepcontrol.launch_dependent_grids` is issued. 
+ Default is `0.5`, meaning after approximately half of K tiles are loaded by DMA warps. +2. `prefetch_ratio`: what percentage of K tiles to prefetch. + Default is `-1.0`, meaning prefetching will stop as soon as other DMA warps are past + `griddepcontrol`. + +It is highly recommended to auto-tune these parameters per GEMM and according to some end to end +runtime (either an entire transformer layer or multiple, but probably not the entire model.) + +TMA loads use non-default cache hints: `A` (weights) are loaded with `EvictFirst`, and `B` (activation) +is loaded with `EvictLast`. + +## Getting started +To use this kernel in your own target, add this directory to your includes, and include the +following headers from this example: + +```cxx +#include "collective/dispatch_policy_extra.hpp" +#include "collective/builder.hpp" +#include "kernel/sm90_gemm_tma_warpspecialized_with_prefetch.hpp" +``` + +And then use either one of the new kernel schedules: + +```cxx +// Without separate warps for A and B +using KernelSchedule = cutlass::gemm::KernelTmaWarpSpecializedFP8FastAccumWithPrefetch; + +// With separate warps for A and B +using KernelSchedule = cutlass::gemm::KernelTmaWarpSpecializedFP8FastAccumWithPrefetchAndSplitDMA; +``` + +The kernel with separate warps for A and B ( +`KernelTmaWarpSpecializedFP8FastAccumWithPrefetchAndSplitDMA`) +is expected to be more performant than the other, especially since it allows the kernel to load +weights into shmem ahead of the `griddepcontrol`. + +As for other GEMM parameters, Thread Block Cluster larger than 1 CTA are not yet supported, and +obviously the kernel layer implementation is warp specialized and uses the TMA, and other kernel +layers or collectives require reimplementation. + +## Example + +Using the example is mostly straightforward. +Just build, and run with your choice of `MNK`: + +```bash +./63_hopper_gemm_with_weight_prefetch --m=8192 --n=1 --k=8192 +``` + +You can also disable the overlap or try different overlap and prefetch ratios and see the +difference: + +```bash +echo "Without overlap and prefetch" +./63_hopper_gemm_with_weight_prefetch --o=-1.0 --p=-1.0 + +echo "Overlap ratio of 0.5, best effort prefetch" +./63_hopper_gemm_with_weight_prefetch --o=0.5 --p=-1.0 + +echo "Overlap ratio of 0.8, prefetch ratio of 0.7" +./63_hopper_gemm_with_weight_prefetch --o=0.8 --p=0.7 +``` + +However, note that the example still runs a single GEMM, and most of the performance improvement +is expected in end to end applications. + + +## Limitations +* The parameter defaults are typically not good choices, especially `prefetch_ratio`. + When `prefetch_ratio` is unspecified (set to `-1.0`), the prefetch warp will `try_wait` on a + memory barrier before issuing every single TMA load, and in many cases this will slow down + prefetching to the point of being almost ineffective. diff --git a/examples/63_hopper_gemm_with_weight_prefetch/collective/builder.hpp b/examples/63_hopper_gemm_with_weight_prefetch/collective/builder.hpp new file mode 100644 index 0000000000..57365a8b36 --- /dev/null +++ b/examples/63_hopper_gemm_with_weight_prefetch/collective/builder.hpp @@ -0,0 +1,215 @@ +/*************************************************************************************************** + * Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+ * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +#pragma once + +#include "cutlass/gemm/collective/collective_builder.hpp" + +#include "dispatch_policy_extra.hpp" +#include "sm90_mma_tma_gmma_ss_warpspecialized_with_prefetch.hpp" + +namespace cutlass::gemm::collective { + +// GMMA_TMA_WS_FP8_FAST_ACCUM_SS + prefetch +template < + class ElementA, + class GmemLayoutATag, + int AlignmentA, + class ElementB, + class GmemLayoutBTag, + int AlignmentB, + class ElementAccumulator, + class TileShape_MNK, + class ClusterShape_MNK, + class StageCountType, + class KernelScheduleType +> +struct CollectiveBuilder< + arch::Sm90, + arch::OpClassTensorOp, + ElementA, + GmemLayoutATag, + AlignmentA, + ElementB, + GmemLayoutBTag, + AlignmentB, + ElementAccumulator, + TileShape_MNK, + ClusterShape_MNK, + StageCountType, + KernelScheduleType, + cute::enable_if_t< + cute::is_same_v> +> { + static_assert(is_static::value); + static_assert(is_static::value); + static_assert(detail::is_aligned(), + "Not meet TMA alignment requirement yet\n"); + static_assert(detail::is_input_fp8(), + "Only FP8 datatypes are compatible with these kernel schedules\n"); + // Dispatch TN fp8 kernels only to TMA warp specialized FP8 builder + static_assert(!detail::is_use_rmem_A(), + "Not supported for fp8 non-TN warp specialized kernels yet\n"); +#ifndef CUTLASS_SM90_COLLECTIVE_BUILDER_SUPPORTED + static_assert(cutlass::detail::dependent_false, "Unsupported Toolkit for SM90 Collective Builder\n"); +#endif + + static constexpr cute::GMMA::Major GmmaMajorA = detail::gmma_ss_tag_to_major_A(); + static constexpr cute::GMMA::Major GmmaMajorB = detail::gmma_ss_tag_to_major_B(); + + using AtomLayoutMNK = Layout>; + + using TiledMma = decltype(cute::make_tiled_mma(cute::GMMA::ss_op_selector< + ElementA, ElementB, ElementAccumulator, TileShape_MNK, GmmaMajorA, GmmaMajorB>(), AtomLayoutMNK{})); + + using GmemTiledCopyA = 
decltype(detail::sm90_cluster_shape_to_tma_atom(shape<1>(ClusterShape_MNK{}))); + using GmemTiledCopyB = decltype(detail::sm90_cluster_shape_to_tma_atom(shape<0>(ClusterShape_MNK{}))); + + using SmemLayoutAtomA = decltype(detail::ss_smem_selector< + GmmaMajorA, ElementA, decltype(cute::get<0>(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>()); + using SmemLayoutAtomB = decltype(detail::ss_smem_selector< + GmmaMajorB, ElementB, decltype(cute::get<1>(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>()); + + static constexpr int PipelineStages = detail::compute_stage_count_or_override(StageCountType{}); + using DispatchPolicy = MainloopSm90TmaGmmaWarpSpecializedWithPrefetch; + + using SmemCopyAtomA = void; + using SmemCopyAtomB = void; + + using CollectiveOp = CollectiveMma< + DispatchPolicy, + TileShape_MNK, + ElementA, + TagToStrideA_t, + ElementB, + TagToStrideB_t, + TiledMma, + GmemTiledCopyA, + SmemLayoutAtomA, + SmemCopyAtomA, + cute::identity, + GmemTiledCopyB, + SmemLayoutAtomB, + SmemCopyAtomB, + cute::identity + >; +}; + +// GMMA_TMA_WS_FP8_FAST_ACCUM_SS + prefetch and split DMA warps +template < + class ElementA, + class GmemLayoutATag, + int AlignmentA, + class ElementB, + class GmemLayoutBTag, + int AlignmentB, + class ElementAccumulator, + class TileShape_MNK, + class ClusterShape_MNK, + class StageCountType, + class KernelScheduleType +> +struct CollectiveBuilder< + arch::Sm90, + arch::OpClassTensorOp, + ElementA, + GmemLayoutATag, + AlignmentA, + ElementB, + GmemLayoutBTag, + AlignmentB, + ElementAccumulator, + TileShape_MNK, + ClusterShape_MNK, + StageCountType, + KernelScheduleType, + cute::enable_if_t< + cute::is_same_v> +> { + static_assert(is_static::value); + static_assert(is_static::value); + static_assert(detail::is_aligned(), + "Not meet TMA alignment requirement yet\n"); + static_assert(detail::is_input_fp8(), + "Only FP8 datatypes are compatible with these kernel schedules\n"); + // Dispatch TN fp8 kernels only to TMA warp specialized FP8 builder + static_assert(!detail::is_use_rmem_A(), + "Not supported for fp8 non-TN warp specialized kernels yet\n"); +#ifndef CUTLASS_SM90_COLLECTIVE_BUILDER_SUPPORTED + static_assert(cutlass::detail::dependent_false, "Unsupported Toolkit for SM90 Collective Builder\n"); +#endif + + static constexpr cute::GMMA::Major GmmaMajorA = detail::gmma_ss_tag_to_major_A(); + static constexpr cute::GMMA::Major GmmaMajorB = detail::gmma_ss_tag_to_major_B(); + + using AtomLayoutMNK = Layout>; + + using TiledMma = decltype(cute::make_tiled_mma(cute::GMMA::ss_op_selector< + ElementA, ElementB, ElementAccumulator, TileShape_MNK, GmmaMajorA, GmmaMajorB>(), AtomLayoutMNK{})); + + using GmemTiledCopyA = decltype(detail::sm90_cluster_shape_to_tma_atom(shape<1>(ClusterShape_MNK{}))); + using GmemTiledCopyB = decltype(detail::sm90_cluster_shape_to_tma_atom(shape<0>(ClusterShape_MNK{}))); + + using SmemLayoutAtomA = decltype(detail::ss_smem_selector< + GmmaMajorA, ElementA, decltype(cute::get<0>(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>()); + using SmemLayoutAtomB = decltype(detail::ss_smem_selector< + GmmaMajorB, ElementB, decltype(cute::get<1>(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>()); + + static constexpr int PipelineStages = detail::compute_stage_count_or_override(StageCountType{}); + using DispatchPolicy = MainloopSm90TmaGmmaWarpSpecializedWithPrefetch; + + using SmemCopyAtomA = void; + using SmemCopyAtomB = void; + + using CollectiveOp = CollectiveMma< + DispatchPolicy, + TileShape_MNK, + 
ElementA, + TagToStrideA_t, + ElementB, + TagToStrideB_t, + TiledMma, + GmemTiledCopyA, + SmemLayoutAtomA, + SmemCopyAtomA, + cute::identity, + GmemTiledCopyB, + SmemLayoutAtomB, + SmemCopyAtomB, + cute::identity + >; +}; + +} // namespace cutlass::gemm::collective + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/examples/63_hopper_gemm_with_weight_prefetch/collective/dispatch_policy_extra.hpp b/examples/63_hopper_gemm_with_weight_prefetch/collective/dispatch_policy_extra.hpp new file mode 100644 index 0000000000..37369176f9 --- /dev/null +++ b/examples/63_hopper_gemm_with_weight_prefetch/collective/dispatch_policy_extra.hpp @@ -0,0 +1,61 @@ +/*************************************************************************************************** + * Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +#pragma once + +namespace cutlass::gemm { + +// Standard non-persistent kernel with a single producer warp, and one prefetch warp. +// `A` is assumed to be static, and therefore the producer warp for `A` attempts to load `A` +// while the producer warp is waiting on griddepcontrol. +// GDC `launch_dependent_grids` is issued from the producer warp instead of math warps, and +// according to prefetch ratio. +struct KernelTmaWarpSpecializedFP8FastAccumWithPrefetch { }; + +// Non-persistent kernel with two producer warps (one for each of A and B), and one prefetch warp. +// `A` is assumed to be static, and therefore the producer warp for `A` attempts to load `A` +// while the producer warp for `B` is waiting on griddepcontrol. Producer warp for `A` does not +// wait on griddepcontrol and loads immediately. 
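+//
+// A minimal usage sketch, mirroring 63_hopper_gemm_with_weight_prefetch.cu, where
+// ElementA/B, LayoutA/B, AlignmentA/B, ElementAccumulator, TileShape, and
+// ClusterShape stand in for the operand configuration (the example itself uses
+// StageCountAutoCarveout to reserve epilogue smem instead of StageCountAuto):
+//
+//   using KernelSchedule =
+//       cutlass::gemm::KernelTmaWarpSpecializedFP8FastAccumWithPrefetchAndSplitDMA;
+//   using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder<
+//       cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
+//       ElementA, LayoutA, AlignmentA,
+//       ElementB, LayoutB, AlignmentB,
+//       ElementAccumulator, TileShape, ClusterShape,
+//       cutlass::gemm::collective::StageCountAuto, KernelSchedule>::CollectiveOp;
+//
+// The runtime knobs are then set on the mainloop arguments before initialization:
+//
+//   arguments.mainloop.overlap_ratio  = 0.5f;   // how early dependent grids launch
+//   arguments.mainloop.prefetch_ratio = -1.0f;  // best-effort prefetch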
+struct KernelTmaWarpSpecializedFP8FastAccumWithPrefetchAndSplitDMA { }; + +template< + int Stages_, + class ClusterShape_ = Shape<_1,_1,_1>, + class KernelSchedule = KernelTmaWarpSpecializedFP8FastAccumWithPrefetch +> +struct MainloopSm90TmaGmmaWarpSpecializedWithPrefetch { + constexpr static int Stages = Stages_; + using ClusterShape = ClusterShape_; + using ArchTag = arch::Sm90; + using Schedule = KernelSchedule; +}; + +} // namespace cutlass::gemm diff --git a/examples/63_hopper_gemm_with_weight_prefetch/collective/sm90_mma_tma_gmma_ss_warpspecialized_with_prefetch.hpp b/examples/63_hopper_gemm_with_weight_prefetch/collective/sm90_mma_tma_gmma_ss_warpspecialized_with_prefetch.hpp new file mode 100644 index 0000000000..710224d78c --- /dev/null +++ b/examples/63_hopper_gemm_with_weight_prefetch/collective/sm90_mma_tma_gmma_ss_warpspecialized_with_prefetch.hpp @@ -0,0 +1,867 @@ +/*************************************************************************************************** + * Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ * + **************************************************************************************************/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/gemm/dispatch_policy.hpp" +#include "cutlass/numeric_types.h" +#include "cutlass/pipeline/pipeline.hpp" +#include "cutlass/trace.h" + +#include "cute/arch/cluster_sm90.hpp" +#include "cute/arch/copy_sm90.hpp" +#include "cute/algorithm/functional.hpp" +#include "cute/atom/mma_atom.hpp" +#include "cute/algorithm/gemm.hpp" +#include "cute/tensor_predicate.hpp" +#include "cute/numeric/arithmetic_tuple.hpp" +#include "cutlass/arch/grid_dependency_control.h" + +#include "dispatch_policy_extra.hpp" + +#include "../pipeline/prefetch_pipeline_sm90.hpp" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass::gemm::collective { +using namespace cute; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +// WarpSpecialized Mainloop +template < + int Stages, + class ClusterShape, + class KernelSchedule, + class TileShape_, + class ElementA_, + class StrideA_, + class ElementB_, + class StrideB_, + class TiledMma_, + class GmemTiledCopyA_, + class SmemLayoutAtomA_, + class SmemCopyAtomA_, + class TransformA_, + class GmemTiledCopyB_, + class SmemLayoutAtomB_, + class SmemCopyAtomB_, + class TransformB_> +struct CollectiveMma< + MainloopSm90TmaGmmaWarpSpecializedWithPrefetch, + TileShape_, + ElementA_, + StrideA_, + ElementB_, + StrideB_, + TiledMma_, + GmemTiledCopyA_, + SmemLayoutAtomA_, + SmemCopyAtomA_, + TransformA_, + GmemTiledCopyB_, + SmemLayoutAtomB_, + SmemCopyAtomB_, + TransformB_> +{ + // + // Type Aliases + // + using DispatchPolicy = MainloopSm90TmaGmmaWarpSpecializedWithPrefetch; + using TileShape = TileShape_; + using ElementA = ElementA_; + using StrideA = StrideA_; + using ElementB = ElementB_; + using StrideB = StrideB_; + using TiledMma = TiledMma_; + using ElementAccumulator = typename TiledMma::ValTypeC; + using GmemTiledCopyA = GmemTiledCopyA_; + using GmemTiledCopyB = GmemTiledCopyB_; + using SmemLayoutAtomA = SmemLayoutAtomA_; + using SmemLayoutAtomB = SmemLayoutAtomB_; + using SmemCopyAtomA = SmemCopyAtomA_; + using SmemCopyAtomB = SmemCopyAtomB_; + using TransformA = TransformA_; + using TransformB = TransformB_; + using ArchTag = typename DispatchPolicy::ArchTag; + + static_assert(size<1>(ClusterShape{}) == 1, "Cluster shape N must be 1"); + using CtaShape_MNK = decltype(shape_div(TileShape{}, ClusterShape{})); + + static constexpr int PrefetchStages = 4; + static constexpr int PrefetchInitialStages = 1; + // This determines how much shmem we set aside for prefetch. + // We don't reuse anything loaded by prefetcher, so we can keep + // loading into the same place -- there will be a conflict when + // writing, but it doesn't affect performance as much as the doors + // that this opens. 
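+  // In practice the prefetched tiles land in a single smem stage that is simply
+  // overwritten on each load; nothing in it is ever consumed, since the only goal
+  // of the prefetch is to warm L2.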
+ static constexpr int PrefetchStagesActual = 1; + using PrefetcherPipeline = cutlass::PrefetchPipeline; + + using MainloopPipeline = cutlass::PipelineTmaAsync; + using PipelineState = cutlass::PipelineState; + using PipelineParams = typename MainloopPipeline::Params; + + static_assert(cute::rank(SmemLayoutAtomA{}) == 2, "SmemLayoutAtom must be rank 2 (M/N, K)"); + static_assert((size<0>(TileShape{}) % size<0>(SmemLayoutAtomA{})) == 0, "SmemLayoutAtom must evenly divide tile shape."); + static_assert((size<2>(TileShape{}) % size<1>(SmemLayoutAtomA{})) == 0, "SmemLayoutAtom must evenly divide tile shape."); + + static_assert(cute::rank(SmemLayoutAtomB{}) == 2, "SmemLayoutAtom must be rank 2 (M/N, K)"); + static_assert((size<1>(TileShape{}) % size<0>(SmemLayoutAtomB{})) == 0, "SmemLayoutAtom must evenly divide tile shape."); + static_assert((size<2>(TileShape{}) % size<1>(SmemLayoutAtomB{})) == 0, "SmemLayoutAtom must evenly divide tile shape."); + + // Tile along modes in a way that maximizes the TMA box size. + using SmemLayoutA = decltype(tile_to_shape( + SmemLayoutAtomA{}, + make_shape(shape<0>(TileShape{}), shape<2>(TileShape{}), Int{}), + cute::conditional_t< ::cutlass::gemm::detail::is_major<0,StrideA>(), Step<_2,_1,_3>, Step<_1,_2,_3>>{})); + using SmemLayoutB = decltype(tile_to_shape( + SmemLayoutAtomB{}, + make_shape(shape<1>(TileShape{}), shape<2>(TileShape{}), Int{}), + cute::conditional_t< ::cutlass::gemm::detail::is_major<0,StrideB>(), Step<_2,_1,_3>, Step<_1,_2,_3>>{})); + + static_assert(rank(SmemLayoutA{}) == 3 && size<2>(SmemLayoutA{}) == DispatchPolicy::Stages); + static_assert(rank(SmemLayoutB{}) == 3 && size<2>(SmemLayoutB{}) == DispatchPolicy::Stages); + + using PrefetchSmemLayoutA = decltype(make_layout(make_shape( + cute::Int(SmemLayoutA{})>{}, + cute::Int(SmemLayoutA{})>{}, + cute::Int{}))); + + static constexpr auto prefetch_smem_size = cute::cosize_v; + + static_assert(DispatchPolicy::Stages >= 2, "Specialization requires Stages set to value 2 or more."); + static_assert(cute::is_base_of::value && + cute::is_base_of::value, + "MMA atom must source both A and B operand from smem_desc for this mainloop."); + static_assert(cute::is_same_v || cute::is_same_v, + "GmemTiledCopy - invalid SM90 TMA copy atom specified."); + static_assert(cute::is_same_v || cute::is_same_v, + "GmemTiledCopy - invalid SM90 TMA copy atom specified."); + + // TMA converts f32 input to tf32 when copying from GMEM to SMEM + // For all other types, cast to size equivalent uint type to avoid any rounding by TMA. 
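+  // For example, a half_t or float_e4m3_t operand is moved by TMA as the unsigned
+  // integer type of the same bit width, so its bit pattern reaches smem unchanged.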
+ static constexpr bool ConvertF32toTF32A = cute::is_same_v; + static constexpr bool ConvertF32toTF32B = cute::is_same_v; + using InternalElementA = cute::conditional_t>>; + using InternalElementB = cute::conditional_t>>; + + // Defined outside the class where it's used, to work around MSVC issues + using PrefetcherPipelineStorage = ::cutlass::detail::PrefetcherPipelineSharedStorage; + + struct SharedStorage { + struct TensorStorage : cute::aligned_struct<128, _0> { + cute::array_aligned> smem_A; + cute::array_aligned> smem_B; + cute::array_aligned smem_prefetch; + } tensors; + + using PipelineStorage = typename MainloopPipeline::SharedStorage; + PipelineStorage pipeline; + PrefetcherPipelineStorage prefetcher_pipeline; + }; + using TensorStorage = typename SharedStorage::TensorStorage; + using PipelineStorage = typename SharedStorage::PipelineStorage; + + // Host side kernel arguments + struct Arguments { + ElementA const* ptr_A; + StrideA dA; + ElementB const* ptr_B; + StrideB dB; + uint32_t mma_promotion_interval = 4; + float overlap_ratio = 0.5; + float prefetch_ratio = -1.0; + }; + + // Device side kernel params + struct Params { + // Assumption: StrideA is congruent with Problem_MK + using TMA_A = decltype(make_tma_copy_A_sm90( + GmemTiledCopyA{}, + make_tensor(static_cast(nullptr), repeat_like(StrideA{}, int32_t(0)), StrideA{}), + SmemLayoutA{}(_,_,cute::Int<0>{}), + TileShape{}, + ClusterShape{})); + // Assumption: StrideB is congruent with Problem_NK + using TMA_B = decltype(make_tma_copy_B_sm90( + GmemTiledCopyB{}, + make_tensor(static_cast(nullptr), repeat_like(StrideB{}, int32_t(0)), StrideB{}), + SmemLayoutB{}(_,_,cute::Int<0>{}), + TileShape{}, + ClusterShape{})); + + TMA_A tma_load_a; + TMA_B tma_load_b; + uint32_t tma_transaction_bytes = TmaTransactionBytesMK + TmaTransactionBytesNK; + uint32_t tma_transaction_bytes_mk = TmaTransactionBytesMK; + uint32_t tma_transaction_bytes_nk = TmaTransactionBytesNK; + float overlap_ratio = 0.5; + float prefetch_ratio = -1.0; + }; + + // + // Methods + // + + template + static constexpr Params + to_underlying_arguments(ProblemShape const& problem_shape, Arguments const& args, void* workspace) { + (void) workspace; + + // Optionally append 1s until problem shape is rank-4 (MNKL), in case it is only rank-3 (MNK) + auto problem_shape_MNKL = append<4>(problem_shape, 1); + auto [M,N,K,L] = problem_shape_MNKL; + + auto ptr_A = reinterpret_cast(args.ptr_A); + auto ptr_B = reinterpret_cast(args.ptr_B); + + Tensor tensor_a = make_tensor(ptr_A, make_layout(make_shape(M,K,L), args.dA)); + Tensor tensor_b = make_tensor(ptr_B, make_layout(make_shape(N,K,L), args.dB)); + + typename Params::TMA_A tma_load_a = make_tma_copy_A_sm90( + GmemTiledCopyA{}, + tensor_a, + SmemLayoutA{}(_,_,cute::Int<0>{}), + TileShape{}, + ClusterShape{}); + typename Params::TMA_B tma_load_b = make_tma_copy_B_sm90( + GmemTiledCopyB{}, + tensor_b, + SmemLayoutB{}(_,_,cute::Int<0>{}), + TileShape{}, + ClusterShape{}); + uint32_t transaction_bytes_mk = TmaTransactionBytesMK; + uint32_t transaction_bytes_nk = TmaTransactionBytesNK; + uint32_t transaction_bytes = transaction_bytes_mk + transaction_bytes_nk; + + return { + tma_load_a, + tma_load_b, + transaction_bytes, + transaction_bytes_mk, + transaction_bytes_nk, + args.overlap_ratio, + args.prefetch_ratio + }; + } + + template + static bool + can_implement( + ProblemShape const& problem_shape, + [[maybe_unused]] Arguments const& args) { + constexpr int tma_alignment_bits = 128; + auto problem_shape_MNKL = 
append<4>(problem_shape, 1); + auto [M,N,K,L] = problem_shape_MNKL; + + constexpr int min_tma_aligned_elements_A = tma_alignment_bits / cutlass::sizeof_bits::value; + bool implementable = cutlass::detail::check_alignment(cute::make_shape(M,K,L), StrideA{}); + constexpr int min_tma_aligned_elements_B = tma_alignment_bits / cutlass::sizeof_bits::value; + implementable = implementable && cutlass::detail::check_alignment(cute::make_shape(N,K,L), StrideB{}); + + if (!implementable) { + CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Problem Size doesn't meet the minimum alignment requirements for TMA.\n"); + return false; + } + + if (args.overlap_ratio > 1.0) { + CUTLASS_TRACE_HOST(" CAN IMPLEMENT: `overlap_ratio` must be either negative (disabled) or in [0, 1].\n"); + return false; + } + + if (args.prefetch_ratio > 1.0) { + CUTLASS_TRACE_HOST(" CAN IMPLEMENT: `prefetch_ratio` must be either negative (disabled) or in [0, 1].\n"); + return false; + } + + return true; + } + + static constexpr int K_PIPE_MAX = DispatchPolicy::Stages; + static constexpr int K_PIPE_MMAS = 1; + static constexpr uint32_t TmaTransactionBytesMK = + cutlass::bits_to_bytes(size<0>(SmemLayoutA{}) * size<1>(SmemLayoutA{}) * static_cast(sizeof_bits::value)); + static constexpr uint32_t TmaTransactionBytesNK = + cutlass::bits_to_bytes(size<0>(SmemLayoutB{}) * size<1>(SmemLayoutB{}) * static_cast(sizeof_bits::value)); + + /// Issue Tma Descriptor Prefetch -- ideally from a single thread for best performance + CUTLASS_DEVICE + static void prefetch_tma_descriptors(Params const& mainloop_params) { + cute::prefetch_tma_descriptor(mainloop_params.tma_load_a.get_tma_descriptor()); + cute::prefetch_tma_descriptor(mainloop_params.tma_load_b.get_tma_descriptor()); + } + + /// Set up the data needed by this collective for load and mma. + /// Returns a tuple of tensors. The collective and the kernel layer have the contract + /// Returned tuple must contain at least two elements, with the first two elements being: + /// gA_mkl - The tma tensor, A after a local tile so it has shape (BLK_M,BLK_K,m,k,l) + /// gB_nkl - The tma tensor, B after a local tile so it has shape (BLK_N,BLK_K,n,k,l) + /// The rest of the tensors can be specified as needed by this collective. 
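+  /// For example, load() below slices gA_mkl as gA_mkl(_,_,m_coord,_,l_coord) to
+  /// obtain the (BLK_M,BLK_K,k) tiles of A owned by the current output tile.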
+ template + CUTLASS_DEVICE auto + load_init(ProblemShape_MNKL const& problem_shape_MNKL, Params const& mainloop_params) const { + using X = Underscore; + // Separate out problem shape for convenience + auto [M,N,K,L] = problem_shape_MNKL; + + // TMA requires special handling of strides to deal with coord codomain mapping + // Represent the full tensors -- get these from TMA + Tensor mA_mkl = mainloop_params.tma_load_a.get_tma_tensor(make_shape(M,K,L)); // (m,k,l) + Tensor mB_nkl = mainloop_params.tma_load_b.get_tma_tensor(make_shape(N,K,L)); // (n,k,l) + + // Make tiled views, defer the slice + Tensor gA_mkl = local_tile(mA_mkl, TileShape{}, make_coord(_,_,_), Step<_1, X,_1>{}); // (BLK_M,BLK_K,m,k,l) + Tensor gB_nkl = local_tile(mB_nkl, TileShape{}, make_coord(_,_,_), Step< X,_1,_1>{}); // (BLK_N,BLK_K,n,k,l) + + return cute::make_tuple(gA_mkl, gB_nkl); + } + + template < + class TensorA, class TensorB, + class KTileIterator, class BlockCoord + > + CUTLASS_DEVICE void + load( + Params const& mainloop_params, + MainloopPipeline pipeline, + PrefetcherPipeline prefetcher_pipeline, + PipelineState smem_pipe_write, + TensorA const& gA_mkl, + TensorB const& gB_nkl, + BlockCoord const& blk_coord, + KTileIterator k_tile_iter, int k_tile_count, + int thread_idx, + uint32_t block_rank_in_cluster, + TensorStorage& shared_tensors) { + int lane_predicate = cute::elect_one_sync(); + + if (lane_predicate) { + bool disable_gdc = mainloop_params.overlap_ratio < 0.0; + float overlap_ratio = mainloop_params.overlap_ratio; + int launch_dep_grids_threshold = static_cast(static_cast(k_tile_count - 1) * overlap_ratio); + + Tensor sA = make_tensor(make_smem_ptr(shared_tensors.smem_A.data()), SmemLayoutA{}); // (BLK_M,BLK_K,PIPE) + Tensor sB = make_tensor(make_smem_ptr(shared_tensors.smem_B.data()), SmemLayoutB{}); // (BLK_N,BLK_K,PIPE) + + // + // Prepare the TMA loads for A + // + + constexpr uint32_t cluster_shape_x = get<0>(typename DispatchPolicy::ClusterShape()); + uint2 cluster_local_block_id = {block_rank_in_cluster % cluster_shape_x, block_rank_in_cluster / cluster_shape_x}; + + auto cta_tma_a = mainloop_params.tma_load_a.get_slice(cluster_local_block_id.y); + auto cta_tma_b = mainloop_params.tma_load_b.get_slice(cluster_local_block_id.x); + + // Partition the inputs based on the current block coordinates. + auto [m_coord, n_coord, k_coord, l_coord] = blk_coord; + Tensor gA = gA_mkl(_,_,m_coord,_,l_coord); // (BLK_M,BLK_K,k) + Tensor gB = gB_nkl(_,_,n_coord,_,l_coord); // (BLK_N,BLK_K,k) + + // Applies the mapping from cta_tma_a + Tensor tAgA = cta_tma_a.partition_S(gA); // (TMA,TMA_M,TMA_K,k) + Tensor tAsA = cta_tma_a.partition_D(sA); // (TMA,TMA_M,TMA_K,PIPE) + + // Applies the mapping from cta_tma_b + Tensor tBgB = cta_tma_b.partition_S(gB); // (TMA,TMA_N,TMA_K,k) + Tensor tBsB = cta_tma_b.partition_D(sB); // (TMA,TMA_N,TMA_K,PIPE) + + uint16_t mcast_mask_a = 0; + uint16_t mcast_mask_b = 0; + + // Issue TmaLoads + // Maps the tile -> block, value + if constexpr (cute::is_same_v) { + auto block_layout = Layout{}; // (m,n) -> block_id + for (int n = 0; n < size<1>(block_layout); ++n) { + mcast_mask_a |= (uint16_t(1) << block_layout(cluster_local_block_id.x,n,Int<0>{})); + } + } + + if constexpr (cute::is_same_v) { + auto block_layout = Layout{}; // (m,n) -> block_id + for (int m = 0; m < size<0>(block_layout); ++m) { + mcast_mask_b |= (uint16_t(1) << block_layout(m,cluster_local_block_id.y,Int<0>{})); + } + } + + // We have to wait on dependent grids because of B. 
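+      // B holds the activations produced by the preceding kernel, so the TMA loads
+      // below must not start until that kernel has flushed its output to gmem.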
+ cutlass::arch::wait_on_dependent_grids(); + + // Signal prefetcher to stop + prefetcher_pipeline.producer_arrive(); + + bool launch_dep_grids = false; + // Mainloop + CUTLASS_PRAGMA_NO_UNROLL + for (int cnt=0 ; k_tile_count > 0; --k_tile_count, ++cnt) { + // LOCK smem_pipe_write for _writing_ + pipeline.producer_acquire(smem_pipe_write); + + // + // Copy gmem to smem for *k_tile_iter + // + + using BarrierType = typename MainloopPipeline::ProducerBarrierType; + BarrierType* tma_barrier = pipeline.producer_get_barrier(smem_pipe_write); + + int write_stage = smem_pipe_write.index(); + copy(mainloop_params.tma_load_a.with(*tma_barrier, mcast_mask_a, cute::TMA::CacheHintSm90::EVICT_FIRST), tAgA(_,_,_,*k_tile_iter), tAsA(_,_,_,write_stage)); + copy(mainloop_params.tma_load_b.with(*tma_barrier, mcast_mask_b, cute::TMA::CacheHintSm90::EVICT_LAST), tBgB(_,_,_,*k_tile_iter), tBsB(_,_,_,write_stage)); + ++k_tile_iter; + + if (!disable_gdc && cnt >= launch_dep_grids_threshold && !launch_dep_grids) { + launch_dep_grids = true; + cutlass::arch::launch_dependent_grids(); + } + + // Advance smem_pipe_write + ++smem_pipe_write; + } + if (!disable_gdc && !launch_dep_grids) { + cutlass::arch::launch_dependent_grids(); + } + } + } + + template < + class TensorA, + class KTileIterator, class BlockCoord + > + CUTLASS_DEVICE void + load_MK( + Params const& mainloop_params, + MainloopPipeline pipeline, + PrefetcherPipeline prefetcher_pipeline, + PipelineState smem_pipe_write, + TensorA const& gA_mkl, + BlockCoord const& blk_coord, + KTileIterator k_tile_iter, int k_tile_count, + int thread_idx, + uint32_t block_rank_in_cluster, + TensorStorage& shared_tensors) { + int lane_predicate = cute::elect_one_sync(); + + if (lane_predicate) { + bool disable_gdc = mainloop_params.overlap_ratio < 0.0; + float overlap_ratio = mainloop_params.overlap_ratio; + int launch_dep_grids_threshold = static_cast(static_cast(k_tile_count - 1) * overlap_ratio); + + Tensor sA = make_tensor(make_smem_ptr(shared_tensors.smem_A.data()), SmemLayoutA{}); // (BLK_M,BLK_K,PIPE) + + // + // Prepare the TMA loads for A + // + + constexpr uint32_t cluster_shape_x = get<0>(typename DispatchPolicy::ClusterShape()); + uint2 cluster_local_block_id = {block_rank_in_cluster % cluster_shape_x, block_rank_in_cluster / cluster_shape_x}; + + auto cta_tma_a = mainloop_params.tma_load_a.get_slice(cluster_local_block_id.y); + + // Partition the inputs based on the current block coordinates. + auto [m_coord, n_coord, k_coord, l_coord] = blk_coord; + Tensor gA = gA_mkl(_,_,m_coord,_,l_coord); // (BLK_M,BLK_K,k) + + // Applies the mapping from cta_tma_a + Tensor tAgA = cta_tma_a.partition_S(gA); // (TMA,TMA_M,TMA_K,k) + Tensor tAsA = cta_tma_a.partition_D(sA); // (TMA,TMA_M,TMA_K,PIPE) + + uint16_t mcast_mask_a = 0; + + // Issue TmaLoads + // Maps the tile -> block, value + if constexpr (cute::is_same_v) { + auto block_layout = Layout{}; // (m,n) -> block_id + for (int n = 0; n < size<1>(block_layout); ++n) { + mcast_mask_a |= (uint16_t(1) << block_layout(cluster_local_block_id.x,n,Int<0>{})); + } + } + + // Don't wait on dependent grids when loading `A`, because + // we assume `A` (weights) are static. 
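+ // This lets the MK (A-only) loads start before the preceding grid has flushed its
+ // outputs; only the NK (B) loads have to wait on griddepcontrol.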
+ + bool launch_dep_grids = false; + // Mainloop + CUTLASS_PRAGMA_NO_UNROLL + for (int cnt=0 ; k_tile_count > 0; --k_tile_count, ++cnt) { + // LOCK smem_pipe_write for _writing_ + pipeline.producer_acquire(smem_pipe_write); + + // + // Copy gmem to smem for *k_tile_iter + // + + using BarrierType = typename MainloopPipeline::ProducerBarrierType; + BarrierType* tma_barrier = pipeline.producer_get_barrier(smem_pipe_write); + + int write_stage = smem_pipe_write.index(); + copy(mainloop_params.tma_load_a.with(*tma_barrier, mcast_mask_a, cute::TMA::CacheHintSm90::EVICT_FIRST), tAgA(_,_,_,*k_tile_iter), tAsA(_,_,_,write_stage)); + ++k_tile_iter; + + if (!disable_gdc && cnt >= launch_dep_grids_threshold && !launch_dep_grids) { + launch_dep_grids = true; + cutlass::arch::launch_dependent_grids(); + } + + // Advance smem_pipe_write + ++smem_pipe_write; + } + if (!disable_gdc && !launch_dep_grids) { + cutlass::arch::launch_dependent_grids(); + } + } + } + + template < + class TensorB, + class KTileIterator, class BlockCoord + > + CUTLASS_DEVICE void + load_NK( + Params const& mainloop_params, + MainloopPipeline pipeline, + PrefetcherPipeline prefetcher_pipeline, + PipelineState smem_pipe_write, + TensorB const& gB_nkl, + BlockCoord const& blk_coord, + KTileIterator k_tile_iter, int k_tile_count, + int thread_idx, + uint32_t block_rank_in_cluster, + TensorStorage& shared_tensors) { + int lane_predicate = cute::elect_one_sync(); + + if (lane_predicate) { + Tensor sB = make_tensor(make_smem_ptr(shared_tensors.smem_B.data()), SmemLayoutB{}); // (BLK_N,BLK_K,PIPE) + + // + // Prepare the TMA loads for B + // + + constexpr uint32_t cluster_shape_x = get<0>(typename DispatchPolicy::ClusterShape()); + uint2 cluster_local_block_id = {block_rank_in_cluster % cluster_shape_x, block_rank_in_cluster / cluster_shape_x}; + + auto cta_tma_b = mainloop_params.tma_load_b.get_slice(cluster_local_block_id.x); + + // Partition the inputs based on the current block coordinates. 
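+ // Only the N and L coordinates are used on this path, since it loads B alone.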
+ auto [m_coord, n_coord, k_coord, l_coord] = blk_coord; + Tensor gB = gB_nkl(_,_,n_coord,_,l_coord); // (BLK_N,BLK_K,k) + + // Applies the mapping from cta_tma_b + Tensor tBgB = cta_tma_b.partition_S(gB); // (TMA,TMA_N,TMA_K,k) + Tensor tBsB = cta_tma_b.partition_D(sB); // (TMA,TMA_N,TMA_K,PIPE) + + uint16_t mcast_mask_b = 0; + + // Issue TmaLoads + // Maps the tile -> block, value + if constexpr (cute::is_same_v) { + auto block_layout = Layout{}; // (m,n) -> block_id + for (int m = 0; m < size<0>(block_layout); ++m) { + mcast_mask_b |= (uint16_t(1) << block_layout(m,cluster_local_block_id.y,Int<0>{})); + } + } + + // Ensure that the prefetched kernel does not touch + // unflushed global memory prior to this instruction + cutlass::arch::wait_on_dependent_grids(); + + // Signal prefetcher to stop + prefetcher_pipeline.producer_arrive(); + + // Mainloop + CUTLASS_PRAGMA_NO_UNROLL + for (; k_tile_count > 0; --k_tile_count) { + // LOCK smem_pipe_write for _writing_ + pipeline.producer_acquire(smem_pipe_write); + + // + // Copy gmem to smem for *k_tile_iter + // + + using BarrierType = typename MainloopPipeline::ProducerBarrierType; + BarrierType* tma_barrier = pipeline.producer_get_barrier(smem_pipe_write); + + int write_stage = smem_pipe_write.index(); + copy(mainloop_params.tma_load_b.with(*tma_barrier, mcast_mask_b, cute::TMA::CacheHintSm90::EVICT_LAST), tBgB(_,_,_,*k_tile_iter), tBsB(_,_,_,write_stage)); + ++k_tile_iter; + + // Advance smem_pipe_write + ++smem_pipe_write; + } + } + } + + /// Perform a Producer Epilogue to prevent early exit of blocks in a Cluster + CUTLASS_DEVICE void + load_tail(MainloopPipeline pipeline, PipelineState smem_pipe_write) { + int lane_predicate = cute::elect_one_sync(); + + // Issue the epilogue waits + if (lane_predicate) { + /* This helps avoid early exit of blocks in Cluster + * Waits for all stages to either be released (all + * Consumer UNLOCKs), or if the stage was never used + * then would just be acquired since the phase was + * still inverted from make_producer_start_state + */ + pipeline.producer_tail(smem_pipe_write); + } + } + + + template < + class TensorA, + class KTileIterator, class BlockCoord + > + CUTLASS_DEVICE void + prefetch_MK( + Params const& mainloop_params, + PrefetcherPipeline prefetcher_pipeline, + PipelineState smem_pipe_write, + TensorA const& gA_mkl, + BlockCoord const& blk_coord, + KTileIterator k_tile_iter, int k_tile_count, + int thread_idx, + uint32_t block_rank_in_cluster, + TensorStorage& shared_tensors) { + int lane_predicate = cute::elect_one_sync(); + + if (lane_predicate) { + bool do_best_effort_prefetch = mainloop_params.prefetch_ratio < 0; + float prefetch_ratio = do_best_effort_prefetch ? 1.0 : mainloop_params.prefetch_ratio; + int prefetch_iters = static_cast(static_cast(k_tile_count) * 0.5 * prefetch_ratio); + prefetch_iters = min(k_tile_count, ((prefetch_iters + PrefetchStages - 1) / PrefetchStages) * PrefetchStages); + + Tensor sA = make_tensor( + make_smem_ptr(shared_tensors.smem_prefetch.data()), PrefetchSmemLayoutA{}); // (BLK_M,BLK_K,PIPE) + + // + // Prepare the TMA loads for A + // + + constexpr uint32_t cluster_shape_x = get<0>(typename DispatchPolicy::ClusterShape()); + uint2 cluster_local_block_id = {block_rank_in_cluster % cluster_shape_x, block_rank_in_cluster / cluster_shape_x}; + + auto cta_tma_a = mainloop_params.tma_load_a.get_slice(cluster_local_block_id.y); + + // Partition the inputs based on the current block coordinates. 
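+ // The prefetcher stages into a small scratch smem buffer (write_stage below stays 0);
+ // its loads exist only to warm the L2 with A (the weights).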
+ auto [m_coord, n_coord, k_coord, l_coord] = blk_coord; + Tensor gA = gA_mkl(_,_,m_coord,_,l_coord); // (BLK_M,BLK_K,k) + + // Applies the mapping from cta_tma_a + Tensor tAgA = cta_tma_a.partition_S(gA); // (TMA,TMA_M,TMA_K,k) + Tensor tAsA = cta_tma_a.partition_D(sA); // (TMA,TMA_M,TMA_K,PIPE) + + uint16_t mcast_mask_a = 0; + + // Issue TmaLoads + // Maps the tile -> block, value + if constexpr (cute::is_same_v) { + auto block_layout = Layout{}; // (m,n) -> block_id + for (int n = 0; n < size<1>(block_layout); ++n) { + mcast_mask_a |= (uint16_t(1) << block_layout(cluster_local_block_id.x,n,Int<0>{})); + } + } + + uint32_t prefetcher_stage = 0; + uint32_t prefetcher_phase = 0; + CUTLASS_PRAGMA_NO_UNROLL + for (int cnt = 0 ; cnt < prefetch_iters; ++cnt) { + + if (do_best_effort_prefetch && prefetcher_pipeline.have_producers_arrived()) { + break; + } + + prefetcher_pipeline.prefetcher_acquire(prefetcher_stage, prefetcher_phase, cnt >= PrefetchStages); + using BarrierType = typename PrefetcherPipeline::PrefetcherBarrierType; + BarrierType* tma_barrier = prefetcher_pipeline.prefetcher_get_barrier(prefetcher_stage); + + int write_stage = 0; + copy(mainloop_params.tma_load_a.with(*tma_barrier, mcast_mask_a, cute::TMA::CacheHintSm90::EVICT_FIRST), tAgA(_,_,_,*k_tile_iter), tAsA(_,_,_,write_stage)); + ++k_tile_iter; + ++k_tile_iter; + + prefetcher_pipeline.advance_prefetcher_state(prefetcher_stage, prefetcher_phase); + } + prefetcher_pipeline.prefetcher_tail(prefetcher_stage, prefetcher_phase); + } + } + + /// Perform a collective-scoped matrix multiply-accumulate + /// Consumer Perspective + template < + class FrgTensorC + > + CUTLASS_DEVICE void + mma(MainloopPipeline pipeline, + PipelineState smem_pipe_read, + FrgTensorC& accum, + int k_tile_count, + int thread_idx, + TensorStorage& shared_tensors, + Params const& mainloop_params) { + static_assert(is_rmem::value, "C tensor must be rmem resident."); + static_assert(cute::rank(SmemLayoutA{}) == 3, "Smem layout must be rank 3."); + static_assert(cute::rank(SmemLayoutB{}) == 3, "Smem layout must be rank 3."); + static_assert(cute::is_void_v, + "SM90 GMMA mainloops cannot have a non-void copy atom for smem sourced instructions."); + static_assert(cute::is_void_v, + "SM90 GMMA mainloops cannot have a non-void copy atom for smem sourced instructions."); + + Tensor sA = make_tensor(make_smem_ptr(shared_tensors.smem_A.data()), SmemLayoutA{}); // (BLK_M,BLK_K,PIPE) + Tensor sB = make_tensor(make_smem_ptr(shared_tensors.smem_B.data()), SmemLayoutB{}); // (BLK_N,BLK_K,PIPE) + + // + // Define C accumulators and A/B partitioning + // + + TiledMma tiled_mma; + auto thread_mma = tiled_mma.get_thread_slice(thread_idx); + + Tensor tCsA = thread_mma.partition_A(sA); // (MMA,MMA_M,MMA_K,PIPE) + Tensor tCsB = thread_mma.partition_B(sB); // (MMA,MMA_N,MMA_K,PIPE) + + // Allocate "fragments/descriptors" + Tensor tCrA = thread_mma.make_fragment_A(tCsA); // (MMA,MMA_M,MMA_K,PIPE) + Tensor tCrB = thread_mma.make_fragment_B(tCsB); // (MMA,MMA_N,MMA_K,PIPE) + + CUTE_STATIC_ASSERT_V(size<1>(tCsA) == size<1>(accum)); // M + CUTE_STATIC_ASSERT_V(size<1>(tCsB) == size<2>(accum)); // N + CUTE_STATIC_ASSERT_V(size<2>(tCsA) == size<2>(tCsB)); // K + CUTE_STATIC_ASSERT_V(size<3>(tCsA) == size<3>(tCsB)); // PIPE + CUTE_STATIC_ASSERT_V(Int{} == size<2>(sA)); // PIPE + CUTE_STATIC_ASSERT_V(Int{} == size<2>(sB)); // PIPE + + // + // PIPELINED MAIN LOOP + // + static_assert((0 <= K_PIPE_MMAS) && (K_PIPE_MMAS < K_PIPE_MAX), + "ERROR : Incorrect number of MMAs in flight"); + + // We 
release buffers to producer warps(dma load) with some mmas in flight + PipelineState smem_pipe_release = smem_pipe_read; + + // Prologue GMMAs + int prologue_mma_count = min(K_PIPE_MMAS, k_tile_count); + + tiled_mma.accumulate_ = GMMA::ScaleOut::Zero; + + warpgroup_fence_operand(accum); + CUTLASS_PRAGMA_UNROLL + for (int k_tile_prologue = prologue_mma_count; k_tile_prologue > 0; --k_tile_prologue) + { + // WAIT on smem_pipe_read until its data are available (phase bit flips from rdPhaseBit value) + auto barrier_token = pipeline.consumer_try_wait(smem_pipe_read); + pipeline.consumer_wait(smem_pipe_read, barrier_token); + + int read_stage = smem_pipe_read.index(); + warpgroup_arrive(); + // Unroll the K mode manually to set scale D to 1 + CUTLASS_PRAGMA_UNROLL + for (int k_block = 0; k_block < size<2>(tCrA); ++k_block) { + // (V,M,K) x (V,N,K) => (V,M,N) + cute::gemm(tiled_mma, tCrA(_,_,k_block,read_stage), tCrB(_,_,k_block,read_stage), accum); + tiled_mma.accumulate_ = GMMA::ScaleOut::One; + } + + warpgroup_commit_batch(); + + ++smem_pipe_read; + } + + warpgroup_fence_operand(accum); + // Mainloop GMMAs + k_tile_count -= prologue_mma_count; + + CUTLASS_PRAGMA_NO_UNROLL + for ( ; k_tile_count > 0; --k_tile_count) + { + // WAIT on smem_pipe_read until its data are available (phase bit flips from rdPhaseBit value) + auto barrier_token = pipeline.consumer_try_wait(smem_pipe_read); + pipeline.consumer_wait(smem_pipe_read, barrier_token); + + // + // Compute on k_tile + // + + int read_stage = smem_pipe_read.index(); + warpgroup_fence_operand(accum); + warpgroup_arrive(); + // Unroll the K mode manually to set scale D to 1 + CUTLASS_PRAGMA_UNROLL + for (int k_block = 0; k_block < size<2>(tCrA); ++k_block) { + // (V,M,K) x (V,N,K) => (V,M,N) + cute::gemm(tiled_mma, tCrA(_,_,k_block,read_stage), tCrB(_,_,k_block,read_stage), accum); + tiled_mma.accumulate_ = GMMA::ScaleOut::One; + } + warpgroup_commit_batch(); + + /// Wait on the GMMA barrier for K_PIPE_MMAS (or fewer) outstanding to ensure smem_pipe_write is consumed + warpgroup_wait(); + warpgroup_fence_operand(accum); + + // UNLOCK smem_pipe_release, done _computing_ on it + pipeline.consumer_release(smem_pipe_release); + + // Advance smem_pipe_read and smem_pipe_release + ++smem_pipe_read; + ++smem_pipe_release; + } + + warpgroup_fence_operand(accum); + } + + /// Perform a Consumer Epilogue to release all buffers + CUTLASS_DEVICE void + mma_tail(MainloopPipeline pipeline, PipelineState smem_pipe_release, int k_tile_count) { + // Prologue GMMAs + int prologue_mma_count = min(K_PIPE_MMAS, k_tile_count); + k_tile_count -= prologue_mma_count; + + smem_pipe_release.advance(k_tile_count); + + // Wait on all GMMAs to complete + warpgroup_wait<0>(); + + for (int count = 0; count < prologue_mma_count; ++count) { + pipeline.consumer_release(smem_pipe_release); // UNLOCK smem_pipe_release, done _computing_ on it + ++smem_pipe_release; + } + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass::gemm::collective + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/examples/63_hopper_gemm_with_weight_prefetch/gemm_with_weight_prefetch_commandline.hpp b/examples/63_hopper_gemm_with_weight_prefetch/gemm_with_weight_prefetch_commandline.hpp new file mode 100644 index 0000000000..6be87768ee --- /dev/null +++ b/examples/63_hopper_gemm_with_weight_prefetch/gemm_with_weight_prefetch_commandline.hpp @@ -0,0 +1,117 @@ 
+/*************************************************************************************************** + * Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +// Command line options parsing +struct Options { + + bool help = false; + + float alpha = 1.f, beta = 0.f; + float overlap_ratio = 0.5f, prefetch_ratio = 0.5f; + int iterations = 1000; + int n = 64, m = 1280, k = 8192, l = 1; + + // Parses the command line + void parse(int argc, char const **args) { + cutlass::CommandLine cmd(argc, args); + + if (cmd.check_cmd_line_flag("help")) { + help = true; + return; + } + + cmd.get_cmd_line_argument("m", m); + cmd.get_cmd_line_argument("n", n); + cmd.get_cmd_line_argument("k", k); + cmd.get_cmd_line_argument("l", l); + cmd.get_cmd_line_argument("alpha", alpha, 1.f); + cmd.get_cmd_line_argument("beta", beta, 0.f); + cmd.get_cmd_line_argument("p", prefetch_ratio, 0.5f); + cmd.get_cmd_line_argument("o", overlap_ratio, 0.5f); + cmd.get_cmd_line_argument("iterations", iterations); + } + + /// Prints the usage statement. + std::ostream & print_usage(std::ostream &out) const { + + out << "63_hopper_gemm_with_weight_prefetch\n\n" + << " Hopper FP8 GEMM using a non-persistent kernel with L2 weight prefetch. 
\n" + << " For more details please refer to the source file.\n\n" + << "Options:\n\n" + << " --help If specified, displays this usage statement\n\n" + << " --m= Sets the M extent of the GEMM\n" + << " --n= Sets the N extent of the GEMM\n" + << " --k= Sets the K extent of the GEMM\n" + << " --l= Sets the l extent (batch) of the GEMM\n" + << " --alpha= Epilogue scalar alpha\n" + << " --beta= Epilogue scalar beta\n" + << " --p= Prefetch ratio\n" + << " --o= Overlap ratio\n" + << " --iterations= Number of profiling iterations to perform.\n\n"; + + out + << "\n\nExamples:\n\n" + << "$ " << "63_hopper_gemm_with_weight_prefetch" << + " --m=1024 --n=512 --k=1024 --o=0.5 --p=0.5 \n\n"; + + return out; + } + + /// Compute performance in GFLOP/s + double gflops(double runtime_s) const + { + // Two flops per multiply-add + uint64_t flop = uint64_t(2) * m * n * k * l; + double gflop = double(flop) / double(1.0e9); + return gflop / runtime_s; + } + + /// Compute effective bandwidth in GB/sec + double effective_bandwidth( + double runtime_s, + size_t bytes_a, + size_t bytes_b, + size_t bytes_c, + size_t bytes_d + ) const + { + static double const kBytesPerGiB = double(1ull << 30); + + double bytes_in = + (double)(l) * (double)(m) * (double)(k) * (double)(bytes_a) + // A + (double)(l) * (double)(n) * (double)(k) * (double)(bytes_b) + // B + (beta != 0.f ? (double)(l) * (double)(m) * (double)(n) * (double)(bytes_c) : 0.f); // C + double bytes_out = (double)(l) * (double)(m) * (double)(n) * (double)(bytes_d); // D + + double gb_total = (bytes_in + bytes_out) / kBytesPerGiB; + return gb_total / runtime_s; + } +}; diff --git a/examples/63_hopper_gemm_with_weight_prefetch/kernel/sm90_gemm_tma_warpspecialized_with_prefetch.hpp b/examples/63_hopper_gemm_with_weight_prefetch/kernel/sm90_gemm_tma_warpspecialized_with_prefetch.hpp new file mode 100644 index 0000000000..6e33d8fc62 --- /dev/null +++ b/examples/63_hopper_gemm_with_weight_prefetch/kernel/sm90_gemm_tma_warpspecialized_with_prefetch.hpp @@ -0,0 +1,561 @@ +/*************************************************************************************************** + * Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/fast_math.h" +#include "cutlass/kernel_hardware_info.hpp" +#include "cute/arch/cluster_sm90.hpp" +#include "cutlass/arch/reg_reconfig.h" +#include "cutlass/arch/mma_sm90.h" +#include "cutlass/epilogue/collective/detail.hpp" +#include "cutlass/gemm/gemm.h" +#include "cutlass/gemm/dispatch_policy.hpp" +#include "cutlass/gemm/kernel/sm90_tile_scheduler.hpp" +#include "cutlass/pipeline/pipeline.hpp" +#include "cutlass/trace.h" + +#include "cute/tensor.hpp" + +#include "../collective/dispatch_policy_extra.hpp" + +/////////////////////////////////////////////////////////////////////////////// + +namespace cutlass::gemm::kernel { + +/////////////////////////////////////////////////////////////////////////////// + +// GEMM + Prefetch for the A tensor + (optional) split DMA warps +template < + class ProblemShape_, + class CollectiveMainloop_, + class CollectiveEpilogue_, + class TileScheduler_ +> +class GemmUniversal< + ProblemShape_, + CollectiveMainloop_, + CollectiveEpilogue_, + TileScheduler_, + cute::enable_if_t< + cute::is_same_v || + cute::is_same_v + > +> +{ +public: + // + // Type Aliases + // + using ProblemShape = ProblemShape_; + static_assert(cute::rank(ProblemShape{}) == 3 or cute::rank(ProblemShape{}) == 4, + "ProblemShape{} should be or "); + static constexpr bool IsGdcEnabled = cutlass::arch::IsGdcGloballyEnabled; + + static constexpr bool SplitWarps = cute::is_same_v; + + // Mainloop derived types + using CollectiveMainloop = CollectiveMainloop_; + using TileShape = typename CollectiveMainloop::TileShape; + using TiledMma = typename CollectiveMainloop::TiledMma; + using ArchTag = typename CollectiveMainloop::ArchTag; + using ElementA = typename CollectiveMainloop::ElementA; + using StrideA = typename CollectiveMainloop::StrideA; + using ElementB = typename CollectiveMainloop::ElementB; + using StrideB = typename CollectiveMainloop::StrideB; + using DispatchPolicy = typename CollectiveMainloop::DispatchPolicy; + using ElementAccumulator = typename CollectiveMainloop::ElementAccumulator; + using ClusterShape = typename DispatchPolicy::ClusterShape; + using MainloopArguments = typename CollectiveMainloop::Arguments; + using MainloopParams = typename CollectiveMainloop::Params; + static_assert(ArchTag::kMinComputeCapability >= 90); + + // Epilogue derived types + using CollectiveEpilogue = CollectiveEpilogue_; + using ElementC = typename CollectiveEpilogue::ElementC; + using StrideC = typename CollectiveEpilogue::StrideC; + using ElementD = typename CollectiveEpilogue::ElementD; + using StrideD = typename CollectiveEpilogue::StrideD; + using EpilogueArguments = typename CollectiveEpilogue::Arguments; + using EpilogueParams = typename CollectiveEpilogue::Params; + + static_assert(cute::is_void_v or cute::is_same_v, + "TMA warp-specialized kernel does not support specializing the tile 
scheduler."); + using TileSchedulerTag = TileScheduler_; + using TileScheduler = typename detail::TileSchedulerSelector< + TileScheduler_, ArchTag, TileShape, ClusterShape>::Scheduler; + using TileSchedulerArguments = typename TileScheduler::Arguments; + + // Kernel level shared memory storage + struct SharedStorage { + // Mainloop and epilogue don't use smem concurrently since kernel is non-persistent, so we can use a union + union TensorStorage { + using MainloopTensorStorage = typename CollectiveMainloop::TensorStorage; + using EpilogueTensorStorage = typename CollectiveEpilogue::TensorStorage; + + MainloopTensorStorage mainloop; + EpilogueTensorStorage epilogue; + } tensors; + + struct PipelineStorage : cute::aligned_struct<16, _1> { + using MainloopPipelineStorage = typename CollectiveMainloop::PipelineStorage; + using PrefetcherPipelineStorage = typename CollectiveMainloop::PrefetcherPipelineStorage; + using EpiLoadPipelineStorage = typename CollectiveEpilogue::PipelineStorage; + + alignas(16) MainloopPipelineStorage mainloop; + alignas(16) EpiLoadPipelineStorage epi_load; + alignas(16) PrefetcherPipelineStorage prefetcher; + } pipelines; + }; + + static constexpr int SharedStorageSize = sizeof(SharedStorage); + + static constexpr uint32_t NumLoadWarpGroups = 1; + static constexpr uint32_t NumMmaWarpGroups = 1; + static constexpr uint32_t MaxThreadsPerBlock = CUTE_STATIC_V(size(TiledMma{})) + (NumLoadWarpGroups * NumThreadsPerWarpGroup); + static constexpr uint32_t MinBlocksPerMultiprocessor = 1; + + // Device side arguments + struct Arguments { + GemmUniversalMode mode{}; + ProblemShape problem_shape{}; + MainloopArguments mainloop{}; + EpilogueArguments epilogue{}; + KernelHardwareInfo hw_info{}; + TileSchedulerArguments scheduler{}; + }; + + // Kernel entry point API + struct Params { + GemmUniversalMode mode{}; + ProblemShape problem_shape{}; + MainloopParams mainloop{}; + EpilogueParams epilogue{}; + }; + + // + // Methods + // + + // Convert to underlying arguments. In this case, a simple copy for the aliased type. 
+ static + Params + to_underlying_arguments(Arguments const& args, void* workspace) { + (void) workspace; + auto problem_shape = args.problem_shape; + if constexpr (detail::Has_SwapAB_v) { + // swap M/N + get<0>(problem_shape) = get<1>(args.problem_shape); + get<1>(problem_shape) = get<0>(args.problem_shape); + } + return { + args.mode, + problem_shape, + CollectiveMainloop::to_underlying_arguments(args.problem_shape, args.mainloop, workspace), + CollectiveEpilogue::to_underlying_arguments(args.problem_shape, args.epilogue, workspace) + }; + } + + static bool + can_implement(Arguments const& args) { + bool implementable = (args.mode == GemmUniversalMode::kGemm) or + (args.mode == GemmUniversalMode::kBatched && cute::rank(ProblemShape{}) == 4); + if (!implementable) { + CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Arguments or Problem Shape don't meet the requirements.\n"); + return implementable; + } + implementable &= CollectiveMainloop::can_implement(args.problem_shape, args.mainloop); + implementable &= CollectiveEpilogue::can_implement(args.problem_shape, args.epilogue); + implementable &= TileScheduler::can_implement(args.scheduler); + + return implementable; + } + + static + size_t + get_workspace_size(Arguments const& args) { + return 0; + } + + static + cutlass::Status + initialize_workspace(Arguments const& args, void* workspace = nullptr, cudaStream_t stream = nullptr, + CudaHostAdapter* cuda_adapter = nullptr) { + return Status::kSuccess; + } + + // Computes the kernel launch grid shape based on runtime parameters + static dim3 + get_grid_shape(Params const& params) { + auto cluster_shape = ClusterShape{}; + auto tile_shape = TileShape{}; + auto problem_shape_MNKL = append<4>(params.problem_shape, Int<1>{}); + return TileScheduler::get_tiled_cta_shape_mnl( + problem_shape_MNKL, tile_shape, cluster_shape); + } + + static dim3 + get_block_shape() { + return dim3(MaxThreadsPerBlock, 1, 1); + } + + CUTLASS_DEVICE + void + operator()(Params const& params, char* smem_buf) { + using namespace cute; + using X = Underscore; + +#if defined(__CUDA_ARCH_FEAT_SM90_ALL) +# define ENABLE_SM90_KERNEL_LEVEL 1 +#endif + +// Any Tensor Op MMA Atom in the WGMMA ISA is arch conditional to sm90a. +#if ! defined(ENABLE_SM90_KERNEL_LEVEL) + printf("ERROR : Arch conditional MMA instruction used without targeting sm90a compute capability. Aborting.\n"); +#else + + enum class WarpGroupRole { + Producer = 0, + Consumer = 1, + }; + // Split mode: use Warp0 to load NK and epilogue, Warp2 to load MK. + // Non-split mode: use Warp0 to load MK, NK and epilogue, Warp2 is unused. + // Both modes use Warp1 to prefetch. 
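+ // The prefetch warp (Warp1) only warms the L2 with A via its own scratch buffer and
+ // prefetcher pipeline; it never produces mainloop smem stages.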
+ enum class ProducerWarpRole { + Warp0 = 0, + PrefetchMK = 1, + Warp2 = 2, + UnusedWarp = 3 + }; + + // Kernel level shared memory storage + SharedStorage& shared_storage = *reinterpret_cast(smem_buf); + + int thread_idx = int(threadIdx.x); + int lane_idx = canonical_lane_idx(); + int warp_idx = canonical_warp_idx_sync(); + int warp_idx_in_warp_group = warp_idx % NumWarpsPerWarpGroup; + int warp_group_thread_idx = thread_idx % NumThreadsPerWarpGroup; + auto warp_group_role = WarpGroupRole(canonical_warp_group_idx()); + auto producer_warp_role = ProducerWarpRole(warp_idx_in_warp_group); + int lane_predicate = cute::elect_one_sync(); + uint32_t block_rank_in_cluster = cute::block_rank_in_cluster(); + + + // Issue Tma Descriptor Prefetch from a single thread + if ((warp_idx == 0) && lane_predicate) { + CollectiveMainloop::prefetch_tma_descriptors(params.mainloop); + CollectiveEpilogue::prefetch_tma_descriptors(params.epilogue); + } + + // Mainloop Load pipeline + using MainloopPipeline = typename CollectiveMainloop::MainloopPipeline; + typename MainloopPipeline::Params mainloop_pipeline_params; + mainloop_pipeline_params.is_leader = warp_group_thread_idx == 0; + if (warp_group_role == WarpGroupRole::Producer && ( + producer_warp_role == ProducerWarpRole::Warp0 || + producer_warp_role == ProducerWarpRole::Warp2)) { + mainloop_pipeline_params.role = MainloopPipeline::ThreadCategory::Producer; + mainloop_pipeline_params.transaction_bytes = params.mainloop.tma_transaction_bytes; + } + if (warp_group_role == WarpGroupRole::Consumer) { + mainloop_pipeline_params.role = MainloopPipeline::ThreadCategory::Consumer; + } + mainloop_pipeline_params.num_consumers = NumThreadsPerWarpGroup; + MainloopPipeline mainloop_pipeline(shared_storage.pipelines.mainloop, mainloop_pipeline_params, ClusterShape{}); + bool should_prefetch = params.mainloop.prefetch_ratio > 0; + using PrefetcherPipeline = typename CollectiveMainloop::PrefetcherPipeline; + typename PrefetcherPipeline::Params prefetcher_pipeline_params; + prefetcher_pipeline_params.num_prefetchers = 1; + if (warp_group_role == WarpGroupRole::Producer && producer_warp_role == ProducerWarpRole::PrefetchMK) { + prefetcher_pipeline_params.should_prefetch = should_prefetch; + prefetcher_pipeline_params.transaction_bytes = params.mainloop.tma_transaction_bytes_mk; + } + PrefetcherPipeline prefetcher_pipeline(shared_storage.pipelines.prefetcher, prefetcher_pipeline_params); + + // Epilogue Load pipeline + using EpiLoadPipeline = typename CollectiveEpilogue::LoadPipeline; + typename EpiLoadPipeline::Params epi_load_pipeline_params; + if (warp_group_role == WarpGroupRole::Producer && producer_warp_role == ProducerWarpRole::Warp0) { + epi_load_pipeline_params.role = EpiLoadPipeline::ThreadCategory::Producer; + } + if (warp_group_role == WarpGroupRole::Consumer) { + epi_load_pipeline_params.role = EpiLoadPipeline::ThreadCategory::Consumer; + } + epi_load_pipeline_params.dst_blockid = cute::block_rank_in_cluster(); + epi_load_pipeline_params.producer_arv_count = NumThreadsPerWarp; + epi_load_pipeline_params.consumer_arv_count = NumThreadsPerWarpGroup; + if constexpr (CollectiveEpilogue::RequiresTransactionBytes) { + epi_load_pipeline_params.transaction_bytes = params.epilogue.tma_transaction_bytes; + } + EpiLoadPipeline epi_load_pipeline(shared_storage.pipelines.epi_load, epi_load_pipeline_params); + + // Epilogue Store pipeline + using EpiStorePipeline = typename CollectiveEpilogue::StorePipeline; + typename EpiStorePipeline::Params epi_store_pipeline_params; + 
epi_store_pipeline_params.always_wait = true; + EpiStorePipeline epi_store_pipeline(epi_store_pipeline_params); + + // Initialize starting pipeline states for the collectives + // Epilogue store pipe is producer-only (consumer is TMA unit, waits via scoreboarding) + typename CollectiveMainloop::PipelineState mainloop_pipe_consumer_state; + typename CollectiveEpilogue::LoadPipelineState epi_load_pipe_consumer_state; + + // For the DMA Load (producer) we start with an opposite phase + // i.e., we skip all waits since we know that the buffer is indeed empty + PipelineState mainloop_pipe_producer_state = cutlass::make_producer_start_state(); + PipelineState epi_load_pipe_producer_state = cutlass::make_producer_start_state(); + PipelineState epi_store_pipe_producer_state = cutlass::make_producer_start_state(); + + auto cluster_wait_fn = [&] () { + // We need this to guarantee that the Pipeline init is visible + // To all producers and consumer thread blocks in the Cluster + if constexpr (size(ClusterShape{}) > 1) { + // Non-prefetcher warps arrive and wait, + // Prefetcher warp can go ahead without waiting. + cute::cluster_arrive_relaxed(); + if (warp_group_role != WarpGroupRole::Producer || + producer_warp_role != ProducerWarpRole::PrefetchMK) { + cute::cluster_wait(); + } + return [] () {}; + } + else { + // __syncthreads() but only for non prefetcher warps + if (should_prefetch) { + + // Use a named barrier to let the prefetcher warp start loading into the L2 + // without waiting to sync with all other warps. + // All other warps need to sync because the mainloop pipeline init + // should be visible to all of them. + // Prefetcher has its own barriers, and the only warps it would need to sync + // with would be the DMA warps. + using ClusterSyncWithPrefetchBarrier = typename cutlass::arch::NamedBarrier; + auto prefetcher_arrive_barrier = ClusterSyncWithPrefetchBarrier( + blockDim.x * blockDim.y * blockDim.z, + /*reserved_named_barriers_*/ 14); + // Prefetcher warp doesn't arrive on this barrier. + auto cluster_arrive_barrier = ClusterSyncWithPrefetchBarrier( + blockDim.x * blockDim.y * blockDim.z - NumThreadsPerWarp, + /*reserved_named_barriers_*/ 15); + + if (warp_group_role == WarpGroupRole::Producer && producer_warp_role == ProducerWarpRole::PrefetchMK) { + __syncwarp(); + prefetcher_arrive_barrier.arrive(); + } + else if (warp_group_role == WarpGroupRole::Producer) { + prefetcher_arrive_barrier.arrive_and_wait(); + cluster_arrive_barrier.arrive_and_wait(); + } + else { + prefetcher_arrive_barrier.arrive(); + cluster_arrive_barrier.arrive_and_wait(); + } + } else { + __syncthreads(); + } + return [] () {}; + } + } (); + + // Preconditions + static_assert(cute::rank(StrideA{}) == 3, "StrideA must be rank-3: [M, K, L]. If batch mode is not needed, set L stride to Int<0>."); + static_assert(cute::rank(StrideB{}) == 3, "StrideB must be rank-3: [N, K, L]. If batch mode is not needed, set L stride to Int<0>."); + static_assert(cute::rank(StrideC{}) == 3, "StrideC must be rank-3: [M, N, L]. If batch mode is not needed, set L stride to Int<0>."); + static_assert(cute::rank(StrideD{}) == 3, "StrideD must be rank-3: [M, N, L]. 
If batch mode is not needed, set L stride to Int<0>."); + + // Optionally append 1s until problem shape is rank-4 in case it is only rank-3 (MNK) + auto problem_shape_MNKL = append<4>(params.problem_shape, Int<1>{}); + + // Get the appropriate blocks for this thread block -- potential for thread block locality + auto blk_shape = TileShape{}; // (BLK_M,BLK_N,BLK_K) + TiledMma tiled_mma; + + // In a warp specialized kernel, collectives expose data movement and compute operations separately + CollectiveMainloop collective_mainloop; + CollectiveEpilogue collective_epilogue(params.epilogue, shared_storage.tensors.epilogue); + + // Prepare and partition the input tensors. Expects a tuple of tensors where: + // get<0>(load_inputs) is the tma tensor A after local tiling so that it has shape (BLK_M,BLK_K,m,k,l) + // get<1>(load_inputs) is the tma tensor B after local tiling so that it has shape (BLK_N,BLK_K,n,k,l) + auto load_inputs = collective_mainloop.load_init(problem_shape_MNKL, params.mainloop); + static_assert(cute::tuple_size_v >= 2, "Output of load_init must have at least two elements (A, B)"); + + // Extract out partitioned A and B. + Tensor gA_mkl = get<0>(load_inputs); + Tensor gB_nkl = get<1>(load_inputs); + + // Compute m_coord, n_coord, and l_coord with their post-tiled shapes + auto m_coord = idx2crd(int(blockIdx.x), shape<2>(gA_mkl)); + auto n_coord = idx2crd(int(blockIdx.y), shape<2>(gB_nkl)); + auto l_coord = idx2crd(int(blockIdx.z), shape<4>(gB_nkl)); + auto blk_coord = make_coord(m_coord, n_coord, _, l_coord); + + // Get pipeline iterators and increments from tensor shapes + auto k_tile_iter = cute::make_coord_iterator(shape<3>(gA_mkl)); + auto k_tile_count = size<3>(gA_mkl); + + // Wait for all thread blocks in the Cluster + cluster_wait_fn(); + + if (warp_group_role == WarpGroupRole::Producer) { + if (producer_warp_role == ProducerWarpRole::Warp0) { + if constexpr(SplitWarps) { + collective_mainloop.load_NK( + params.mainloop, + mainloop_pipeline, + prefetcher_pipeline, + mainloop_pipe_producer_state, + gB_nkl, + blk_coord, + k_tile_iter, k_tile_count, + lane_idx, + block_rank_in_cluster, + shared_storage.tensors.mainloop + ); + } + else { + collective_mainloop.load( + params.mainloop, + mainloop_pipeline, + prefetcher_pipeline, + mainloop_pipe_producer_state, + gA_mkl, gB_nkl, + blk_coord, + k_tile_iter, k_tile_count, + lane_idx, + block_rank_in_cluster, + shared_storage.tensors.mainloop + ); + } + // Update starting mainloop pipeline state for the pipeline drain + mainloop_pipe_producer_state.advance(k_tile_count); + // Make sure mainloop consumer has been waited upon before issuing epilogue load + collective_mainloop.load_tail(mainloop_pipeline, mainloop_pipe_producer_state); + + if (collective_epilogue.is_producer_load_needed()) { + // Ensure warp is converged before issuing epilogue loads + __syncwarp(); + epi_load_pipe_producer_state = collective_epilogue.load( + epi_load_pipeline, + epi_load_pipe_producer_state, + problem_shape_MNKL, + blk_shape, + blk_coord, + tiled_mma, + lane_idx, + shared_storage.tensors.epilogue + ); + collective_epilogue.load_tail(epi_load_pipeline, epi_load_pipe_producer_state); + } + } + else if (SplitWarps && producer_warp_role == ProducerWarpRole::Warp2) { + collective_mainloop.load_MK( + params.mainloop, + mainloop_pipeline, + prefetcher_pipeline, + mainloop_pipe_producer_state, + gA_mkl, + blk_coord, + k_tile_iter, k_tile_count, + lane_idx, + block_rank_in_cluster, + shared_storage.tensors.mainloop + ); + // Update starting mainloop 
pipeline state for the pipeline drain + mainloop_pipe_producer_state.advance(k_tile_count); + // Make sure mainloop consumer has been waited upon before issuing epilogue load + collective_mainloop.load_tail(mainloop_pipeline, mainloop_pipe_producer_state); + } else if (producer_warp_role == ProducerWarpRole::PrefetchMK && should_prefetch) { + collective_mainloop.prefetch_MK( + params.mainloop, + prefetcher_pipeline, + mainloop_pipe_producer_state, + gA_mkl, + blk_coord, + k_tile_iter, k_tile_count, + lane_idx, + block_rank_in_cluster, + shared_storage.tensors.mainloop + ); + } + } + else if (warp_group_role == WarpGroupRole::Consumer) { + Tensor accumulators = partition_fragment_C(tiled_mma, take<0,2>(blk_shape)); // (MMA,MMA_M,MMA_N) + + collective_mainloop.mma( + mainloop_pipeline, + mainloop_pipe_consumer_state, + accumulators, + k_tile_count, + warp_group_thread_idx, + shared_storage.tensors.mainloop, + params.mainloop + ); + + // Make sure the math instructions are done and free buffers before entering the epilogue + collective_mainloop.mma_tail( + mainloop_pipeline, + mainloop_pipe_consumer_state, + k_tile_count + ); + + // Epilogue and write to gD + auto [epi_load_pipe_consumer_state_next, epi_store_pipe_producer_state_next] = + collective_epilogue.store( + epi_load_pipeline, + epi_load_pipe_consumer_state, + epi_store_pipeline, + epi_store_pipe_producer_state, + problem_shape_MNKL, + blk_shape, + blk_coord, + accumulators, + tiled_mma, + warp_group_thread_idx, + shared_storage.tensors.epilogue + ); + + collective_epilogue.store_tail( + epi_load_pipeline, + epi_load_pipe_consumer_state_next, + epi_store_pipeline, + epi_store_pipe_producer_state_next + ); + } +#endif + } +}; + +/////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass::gemm::kernel diff --git a/examples/63_hopper_gemm_with_weight_prefetch/pipeline/prefetch_pipeline_sm90.hpp b/examples/63_hopper_gemm_with_weight_prefetch/pipeline/prefetch_pipeline_sm90.hpp new file mode 100644 index 0000000000..7abd39ccfc --- /dev/null +++ b/examples/63_hopper_gemm_with_weight_prefetch/pipeline/prefetch_pipeline_sm90.hpp @@ -0,0 +1,161 @@ +/*************************************************************************************************** + * Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cute/arch/cluster_sm90.hpp" +#include "cutlass/arch/barrier.h" +#include "cute/container/array.hpp" + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { + +namespace detail { + +// MSVC work-around +template +struct PrefetcherPipelineSharedStorage { + using TransactionBarrier = cutlass::arch::ClusterTransactionBarrier; + using Barrier = cutlass::arch::ClusterBarrier; + + TransactionBarrier tma_barrier[Stages]; + Barrier producer_ready_barrier; +}; + +} // end namespace detail + +using namespace cute; + +// Prefetcher pipeline is modeled after PipelineTmaAsync, with a cluster transaction +// barrier providing control over the number of concurrent outstanding TMA loads. +// There is also an additional cluster barrier which is only used when `prefetch_ratio` is unset. +// `prefetch_ratio` determines how many K tiles get loaded, and when unset, the prefetcher checks +// whether DMA warps are done waiting on griddepcontrol, and if so, stops issuing more TMA loads. +template +class PrefetchPipeline { +public : + static constexpr uint32_t Stages = Stages_; + using SharedStorage = detail::PrefetcherPipelineSharedStorage; + + using TransactionBarrier = typename SharedStorage::TransactionBarrier; + using Barrier = typename SharedStorage::Barrier; + using PrefetcherBarrierType = typename TransactionBarrier::ValueType; + + struct Params { + uint32_t transaction_bytes = 0; + uint32_t num_prefetchers = 1; + bool should_prefetch = false; + }; + + // Constructor + CUTLASS_DEVICE + PrefetchPipeline(SharedStorage& storage, Params params) + : params_(params) + , tma_barrier_ptr_(&storage.tma_barrier[0]) + , producer_ready_barrier_ptr_(&storage.producer_ready_barrier) { + + int lane_predicate = cute::elect_one_sync(); + if (params.should_prefetch && lane_predicate) { + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < Stages; ++i) { + tma_barrier_ptr_[i].init(params.num_prefetchers); + } + producer_ready_barrier_ptr_[0].init(1); + } + } + + CUTLASS_DEVICE + void producer_arrive() { + if (params_.should_prefetch) { + producer_ready_barrier_ptr_[0].arrive(); + } + } + + CUTLASS_DEVICE + bool have_producers_arrived() { + if (params_.should_prefetch) { + uint32_t barrier_status_ = producer_ready_barrier_ptr_[0].try_wait(0); + auto barrier_status = static_cast(barrier_status_); + if (barrier_status == BarrierStatus::WaitDone) { + return true; // exit prefetcher loop + } + return false; + } + return true; + } + + CUTLASS_DEVICE + void prefetcher_acquire(uint32_t stage, uint32_t phase, bool should_wait) { + if (params_.should_prefetch) { + if (should_wait) { + tma_barrier_ptr_[stage].wait(phase ^ 1); + } + tma_barrier_ptr_[stage].arrive_and_expect_tx(params_.transaction_bytes); + } + } + + CUTLASS_DEVICE + void advance_prefetcher_state(uint32_t& stage, 
uint32_t& phase) { + if (params_.should_prefetch) { + stage++; + if (stage == Stages) { + stage = 0; + phase ^= 1; + } + } + } + + CUTLASS_DEVICE + void prefetcher_tail(uint32_t stage, uint32_t phase) { + if (params_.should_prefetch) { + // Wait on any already-issued loads + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < stage; ++i) { + tma_barrier_ptr_[i].wait(phase); + } + } + } + + CUTLASS_DEVICE + PrefetcherBarrierType* prefetcher_get_barrier(uint32_t stage) { + return reinterpret_cast(&tma_barrier_ptr_[stage]); + } + +private : + TransactionBarrier* tma_barrier_ptr_ = nullptr; + Barrier* producer_ready_barrier_ptr_ = nullptr; + Params params_; + +}; + +} // end namespace cutlass diff --git a/examples/CMakeLists.txt b/examples/CMakeLists.txt index 9cb125d988..6486d71435 100644 --- a/examples/CMakeLists.txt +++ b/examples/CMakeLists.txt @@ -140,6 +140,9 @@ foreach(EXAMPLE 57_hopper_grouped_gemm 58_ada_fp8_gemm 59_ampere_gather_scatter_conv + 61_hopper_gemm_with_topk_and_softmax + 62_hopper_sparse_gemm + 63_hopper_gemm_with_weight_prefetch ) add_subdirectory(${EXAMPLE}) diff --git a/examples/cute/tutorial/tiled_copy.cu b/examples/cute/tutorial/tiled_copy.cu index d370320b1b..87ad873ce6 100644 --- a/examples/cute/tutorial/tiled_copy.cu +++ b/examples/cute/tutorial/tiled_copy.cu @@ -186,8 +186,8 @@ int main(int argc, char** argv) return -1; } // Equivalent check to the above - if (not weakly_compatible(block_shape, tensor_shape)) { - std::cerr << "Expected the tensors to be weakly compatible with the block_shape." << std::endl; + if (not evenly_divides(tensor_shape, block_shape)) { + std::cerr << "Expected the block_shape to evenly divide the tensor shape." << std::endl; return -1; } diff --git a/include/cute/algorithm/clear.hpp b/include/cute/algorithm/clear.hpp index 1c7dd5a334..0b3a8eaa1d 100644 --- a/include/cute/algorithm/clear.hpp +++ b/include/cute/algorithm/clear.hpp @@ -30,9 +30,9 @@ **************************************************************************************************/ #pragma once -#include -#include -#include +#include // CUTE_HOST_DEVICE +#include // cute::Tensor +#include // cute::fill namespace cute { diff --git a/include/cute/algorithm/cooperative_copy.hpp b/include/cute/algorithm/cooperative_copy.hpp index b2be11717f..9d080116da 100644 --- a/include/cute/algorithm/cooperative_copy.hpp +++ b/include/cute/algorithm/cooperative_copy.hpp @@ -31,12 +31,14 @@ #pragma once #include - -#include -#include - -#include +#include +#include // cute::logical_divide +#include // cute::Swizzle +#include // cute::get_nonswizzle_portion +#include // cute::Tensor #include +#include +#include namespace cute { diff --git a/include/cute/algorithm/cooperative_gemm.hpp b/include/cute/algorithm/cooperative_gemm.hpp index da03bfbd11..2c91ce6f45 100644 --- a/include/cute/algorithm/cooperative_gemm.hpp +++ b/include/cute/algorithm/cooperative_gemm.hpp @@ -434,8 +434,8 @@ cooperative_gemm(uint32_t thread_idx, static_assert(is_convertible_v>, TypeC>, "CStoreTransformOp functor must accept value of type TC::value_type and return value convertible to type TC::value_type"); - static constexpr bool compat = weakly_compatible(tile_shape(TiledMMA{}), - make_shape(size<0>(sA), size<0>(sB), size<1>(sA))); + static constexpr bool compat = evenly_divides(make_shape(size<0>(sA), size<0>(sB), size<1>(sA)), + tile_shape(TiledMMA{})); if constexpr (compat) { detail::cooperative_gemm_no_predication( thread_idx, tiled_mma, alpha, sA, sB, beta, sC, diff --git a/include/cute/algorithm/copy.hpp 
b/include/cute/algorithm/copy.hpp index 2a37995eea..c2decd15d7 100644 --- a/include/cute/algorithm/copy.hpp +++ b/include/cute/algorithm/copy.hpp @@ -30,14 +30,10 @@ **************************************************************************************************/ #pragma once -#include - -#include - -#include -#include - -#include +#include // CUTE_HOST_DEVICE +#include // cute::Tensor +#include // cute::TrivialPredTensor +#include // cute::Copy_Atom namespace cute { diff --git a/include/cute/algorithm/functional.hpp b/include/cute/algorithm/functional.hpp index 8e7a58a5bc..ef80d018d7 100644 --- a/include/cute/algorithm/functional.hpp +++ b/include/cute/algorithm/functional.hpp @@ -30,10 +30,9 @@ **************************************************************************************************/ #pragma once -#include - -#include -#include +#include // CUTE_HOST_DEVICE +#include // cute::max, cute::min +#include // cute::conj /** C++14 extensions */ diff --git a/include/cute/algorithm/prefetch.hpp b/include/cute/algorithm/prefetch.hpp index 0d638ab58f..c39f63acdd 100644 --- a/include/cute/algorithm/prefetch.hpp +++ b/include/cute/algorithm/prefetch.hpp @@ -30,11 +30,9 @@ **************************************************************************************************/ #pragma once -#include - -#include - -#include +#include // CUTE_HOST_DEVICE +#include // cute::Tensor +#include // cute::Copy_Atom namespace cute { diff --git a/include/cute/algorithm/tuple_algorithms.hpp b/include/cute/algorithm/tuple_algorithms.hpp index 616960a54a..c87ce682d1 100644 --- a/include/cute/algorithm/tuple_algorithms.hpp +++ b/include/cute/algorithm/tuple_algorithms.hpp @@ -44,7 +44,7 @@ /// Code guidelines and style preferences: /// /// For perfect forwarding, don't use std::forward, because it may not -/// be defined in device code when compiling with NVRTC. Instead, use +/// be defined in device code when compiling with NVRTC. Instead, use /// `static_cast(parameter_name)`. /// /// CuTe generally does not bother forwarding functions, as @@ -52,24 +52,9 @@ /// /// Throughout CUTLASS, cute::make_tuple always needs to be called /// namespace-qualified, EVEN If inside the cute namespace and/or in -/// scope of a "using namespace cute" declaration. Otherwise, the +/// scope of a "using namespace cute" declaration. Otherwise, the /// compiler may select std::make_tuple instead of cute::make_tuple, -/// due to argument-dependent lookup. Two problems may result from -/// that. -/// -/// 1. Functions have an unexpected return type (std::tuple instead of -/// cute::tuple), so functions that take cute::tuple parameters -/// fail to compile (generally inside functions that have template -/// parameters expected to be cute::tuple). -/// -/// 2. std::tuple does not have the required __host__ __device__ -/// markings, so the CUDA compiler complains if you use it in -/// device code. -/// -/// cute::make_tuple will occur more often than std::make_tuple would -/// in modern C++ code, because cute::tuple's design deprioritizes -/// correct operation of CTAD (constructor template argument -/// deduction) in favor of implementation simplicity. +/// due to argument-dependent lookup. 
namespace cute { @@ -145,6 +130,8 @@ transform_apply(T&& t, F&& f, G&& g) } else { return g(f(static_cast(t))); } + + CUTE_GCC_UNREACHABLE; } template @@ -157,6 +144,8 @@ transform_apply(T0&& t0, T1&& t1, F&& f, G&& g) } else { return g(f(static_cast(t0), static_cast(t1))); } + + CUTE_GCC_UNREACHABLE; } template @@ -169,6 +158,8 @@ transform_apply(T0&& t0, T1&& t1, T2&& t2, F&& f, G&& g) } else { return g(f(static_cast(t0), static_cast(t1), static_cast(t2))); } + + CUTE_GCC_UNREACHABLE; } // @@ -401,71 +392,36 @@ filter_tuple(T0 const& t0, T1 const& t1, T2 const& t2, F&& f) namespace detail { -// This impl compiles much faster than cute::apply and variadic args -template -CUTE_HOST_DEVICE constexpr -auto -fold(T&&, V&& v, F&&, seq<>) -{ - return v; -} - -template -CUTE_HOST_DEVICE constexpr -auto -fold(T&& t, V&& v, F&& f, seq) -{ - return f(static_cast(v), get(static_cast(t))); -} - -template -CUTE_HOST_DEVICE constexpr -auto -fold(T&& t, V&& v, F&& f, seq) -{ - return f(f(static_cast(v), get(static_cast(t))), get(static_cast(t))); -} - -template -CUTE_HOST_DEVICE constexpr -auto -fold(T&& t, V&& v, F&& f, seq) -{ - return f(f(f(static_cast(v), get(static_cast(t))), get(static_cast(t))), get(static_cast(t))); -} +template +struct FoldAdaptor { + template + CUTE_HOST_DEVICE constexpr auto operator|(X&& x) { + auto r = fn_(val_, static_cast(x)); + return FoldAdaptor{fn_, r}; + } + Fn fn_; + Val val_; +}; -template +template CUTE_HOST_DEVICE constexpr auto -fold(T&& t, V&& v, F&& f, seq) +fold(T&& t, V const& v, F&& f, seq) { - return f(f(f(f(static_cast(v), get(static_cast(t))), get(static_cast(t))), get(static_cast(t))), get(static_cast(t))); + return (FoldAdaptor{f,v} | ... | get(static_cast(t))).val_; } -template -CUTE_HOST_DEVICE constexpr -auto -fold(T&& t, V&& v, F&& f, seq) -{ - return fold(static_cast(t), - f(f(f(f(static_cast(v), get(static_cast(t))), get(static_cast(t))), get(static_cast(t))), get(static_cast(t))), - f, - seq{}); -} } // end namespace detail template CUTE_HOST_DEVICE constexpr auto -fold(T&& t, V&& v, F&& f) +fold(T&& t, V const& v, F&& f) { if constexpr (is_tuple>::value) { - return detail::fold(static_cast(t), - static_cast(v), - f, - tuple_seq{}); + return detail::fold(static_cast(t), v, f, tuple_seq{}); } else { - return f(static_cast(v), static_cast(t)); + return f(v, static_cast(t)); } CUTE_GCC_UNREACHABLE; @@ -477,10 +433,7 @@ auto fold_first(T&& t, F&& f) { if constexpr (is_tuple>::value) { - return detail::fold(static_cast(t), - get<0>(static_cast(t)), - f, - make_range<1,tuple_size>::value>{}); + return detail::fold(static_cast(t), get<0>(t), f, make_range<1,tuple_size>::value>{}); } else { return t; } @@ -536,13 +489,23 @@ CUTE_HOST_DEVICE constexpr auto take(T const& t) { - return detail::apply(t, [](auto const&... a) { return cute::make_tuple(a...); }, make_range{}); + if constexpr (E == -1) { + if constexpr (is_tuple::value) { + return take::value>(t); + } else { + return take(t); + } + } else + if constexpr (B <= E) { + return detail::apply(t, [](auto const&... a) { return cute::make_tuple(a...); }, make_range{}); + } else { + static_assert(B <= E); + } + + CUTE_GCC_UNREACHABLE; } -// // Select tuple elements with given indices. 
-// - template CUTE_HOST_DEVICE constexpr auto @@ -551,19 +514,6 @@ select(T const& t) return cute::make_tuple(get(t)...); } -template -CUTE_HOST_DEVICE constexpr -auto -select(T const& t, Indices const& indices) -{ - if constexpr (is_tuple::value) { - return cute::transform(indices, [&t](auto i) { return select(t, i); }); - } else { - static_assert(is_static::value, "Order must be static"); - return get(t); - } -} - // Wrap non-tuples into rank-1 tuples or forward template CUTE_HOST_DEVICE constexpr diff --git a/include/cute/arch/cluster_sm90.hpp b/include/cute/arch/cluster_sm90.hpp index 27a34d7773..8fff51be8e 100644 --- a/include/cute/arch/cluster_sm90.hpp +++ b/include/cute/arch/cluster_sm90.hpp @@ -150,7 +150,7 @@ CUTE_DEVICE dim3 cluster_shape() } // Get 1D ctaid in a cluster. -CUTLASS_DEVICE uint32_t block_rank_in_cluster() +CUTE_DEVICE uint32_t block_rank_in_cluster() { #if defined(CUTE_ARCH_CLUSTER_SM90_ENABLED) uint32_t rank; @@ -162,7 +162,7 @@ CUTLASS_DEVICE uint32_t block_rank_in_cluster() } // Set the destination block-ID in cluster for a given SMEM Address -CUTLASS_DEVICE uint32_t set_block_rank(uint32_t smemAddr, uint32_t rank) +CUTE_DEVICE uint32_t set_block_rank(uint32_t smemAddr, uint32_t rank) { #if defined(CUTE_ARCH_CLUSTER_SM90_ENABLED) uint32_t result; diff --git a/include/cute/arch/config.hpp b/include/cute/arch/config.hpp new file mode 100644 index 0000000000..84d7779a34 --- /dev/null +++ b/include/cute/arch/config.hpp @@ -0,0 +1,50 @@ +/*************************************************************************************************** + * Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ * + **************************************************************************************************/ +#pragma once + +#include // CUTLASS_ARCH_MMA_SMxx_ENABLED + +// TMA instructions +#if defined(CUTLASS_ARCH_MMA_SM90_ENABLED) +# define CUTE_ARCH_TMA_SM90_ENABLED +#endif + +#if defined(CUTLASS_ARCH_MMA_MODIFIABLE_TMA_SM90_ENABLED) +# define CUTE_ARCH_DEVICE_MODIFIABLE_TMA_SM90_ENABLED +#endif + +// STSM +#if defined(CUTLASS_ARCH_MMA_SM90_ENABLED) +# define CUTE_ARCH_STSM_SM90_ENABLED +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + diff --git a/include/cute/arch/copy_sm50.hpp b/include/cute/arch/copy_sm50.hpp index 9cf0efcdf5..925d9ebe37 100644 --- a/include/cute/arch/copy_sm50.hpp +++ b/include/cute/arch/copy_sm50.hpp @@ -40,8 +40,8 @@ namespace cute { - -struct SM50_Shuffle_U32_2x2Trans +// Shuffle data between thread pair (0, 1), (2, 3), etc. +struct SM50_Shuffle_U32_2x2Trans_XOR1 { using SRegisters = uint32_t[2]; using DRegisters = uint32_t[2]; @@ -68,5 +68,31 @@ struct SM50_Shuffle_U32_2x2Trans } }; +// Shuffle data between thread pair (0, 4), (1, 5), etc. +struct SM50_Shuffle_U32_2x2Trans_XOR4 +{ + using SRegisters = uint32_t[2]; + using DRegisters = uint32_t[2]; + + CUTE_HOST_DEVICE static void + copy(uint32_t const& src0, uint32_t const& src1, uint32_t& dst0, uint32_t& dst1) + { +#if defined(CUTE_ARCH_WARP_SHUFFLE_ENABLED) + uint32_t x0 = threadIdx.x & 4 ? src0 : src1; + uint32_t y0 = __shfl_xor_sync(0xffffffff, x0, 4); + + // Replace detination register with shuffle result. + if (threadIdx.x & 0x4) { + dst0 = y0; + } + else { + dst1 = y0; + } +#else + CUTE_INVALID_CONTROL_PATH("Trying to use __shfl_xor_sync without CUTE_ARCH_WARP_SHUFFLE_ENABLED."); +#endif + } +}; + } // end namespace cute diff --git a/include/cute/arch/copy_sm90.hpp b/include/cute/arch/copy_sm90.hpp index e5684ec469..bcb3b7d19c 100644 --- a/include/cute/arch/copy_sm90.hpp +++ b/include/cute/arch/copy_sm90.hpp @@ -30,21 +30,10 @@ **************************************************************************************************/ #pragma once -#include - +#include // CUTE_HOST_DEVICE +#include // CUTE_ARCH_TMA_SMxx_ENABLED #include -// Config -#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900) && (__CUDACC_VER_MAJOR__ >= 12)) -# define CUTE_ARCH_STSM_SM90_ENABLED -# define CUTE_ARCH_TMA_SM90_ENABLED -#endif - -#if defined(CUTE_ARCH_TMA_SM90_ENABLED) && \ - ((__CUDACC_VER_MAJOR__ > 12) || ((__CUDACC_VER_MAJOR__ == 12) && (__CUDACC_VER_MINOR__ >= 3))) -# define CUTE_ARCH_DEVICE_MODIFIABLE_TMA_SM90_ENABLED -#endif - namespace cute { diff --git a/include/cute/arch/copy_sm90_desc.hpp b/include/cute/arch/copy_sm90_desc.hpp index 21e473ede9..25a252a8e7 100644 --- a/include/cute/arch/copy_sm90_desc.hpp +++ b/include/cute/arch/copy_sm90_desc.hpp @@ -30,6 +30,8 @@ **************************************************************************************************/ #pragma once +#include "cutlass/numeric_types.h" + #if !defined(__CUDACC_RTC__) #include #include @@ -37,6 +39,8 @@ #include +#include // cute::cast_smem_ptr_to_uint +#include // CUTE_ARCH_TMA_SMxx_ENABLED #include #include @@ -134,6 +138,10 @@ enum class SmemSwizzleBits : uint8_t { B128 = 3, }; +enum class SmemSwizzleBase : uint8_t { + SWIZZLE_BASE_16B = 0, +}; + enum class OOBFill : uint8_t { ZERO = 0, CONSTANT = 1, @@ -201,13 +209,21 @@ to_CUtensorMapDataType() { } inline CUtensorMapSwizzle -to_CUtensorMapSwizzle(SmemSwizzleBits const& t) { +to_CUtensorMapSwizzle(SmemSwizzleBits const& t, 
SmemSwizzleBase const& b) { switch (t) { - default: assert(false && "Unknown SmemSwizzleBits!"); - case SmemSwizzleBits::DISABLE: return CU_TENSOR_MAP_SWIZZLE_NONE; - case SmemSwizzleBits::B32: return CU_TENSOR_MAP_SWIZZLE_32B; - case SmemSwizzleBits::B64: return CU_TENSOR_MAP_SWIZZLE_64B; - case SmemSwizzleBits::B128: return CU_TENSOR_MAP_SWIZZLE_128B; + default: assert(false && "Unsupported pair of SmemSwizzleBits and SmemSwizzleBase!"); + case SmemSwizzleBits::DISABLE: + assert((b == SmemSwizzleBase::SWIZZLE_BASE_16B) && "Expected 16B swizzle base for 0B swizzle bits."); + return CU_TENSOR_MAP_SWIZZLE_NONE; + case SmemSwizzleBits::B32: + assert((b == SmemSwizzleBase::SWIZZLE_BASE_16B) && "Expected 16B swizzle base for 32B swizzle bits."); + return CU_TENSOR_MAP_SWIZZLE_32B; + case SmemSwizzleBits::B64: + assert((b == SmemSwizzleBase::SWIZZLE_BASE_16B) && "Expected 16B swizzle base for 64B swizzle bits."); + return CU_TENSOR_MAP_SWIZZLE_64B; + case SmemSwizzleBits::B128: + assert((b == SmemSwizzleBase::SWIZZLE_BASE_16B) && "Expected 16B swizzle base for 128B swizzle bits."); + return CU_TENSOR_MAP_SWIZZLE_128B; } } @@ -282,7 +298,7 @@ tma_descriptor_replace_addr_in_global_mem(TmaDescriptor const* desc_ptr, "tensormap.replace.tile.global_address.global.b1024.b64 [%0], %1;" :: "l"(gmem_int_desc), "l"(new_desc_addr)); #else - CUTE_INVALID_CONTROL_PATH("Using TMA Descriptor modification without CUTE_ARCH_TMA_SM90_ENABLED and CUDA 12.3"); + CUTE_INVALID_CONTROL_PATH("Using TMA Descriptor modification without CUTE_ARCH_DEVICE_MODIFIABLE_TMA_SM90_ENABLED and CUDA 12.3"); #endif } @@ -295,15 +311,11 @@ tma_descriptor_replace_addr_in_shared_mem(TmaDescriptor& smem_desc, #if defined(CUTE_ARCH_DEVICE_MODIFIABLE_TMA_SM90_ENABLED) uint32_t smem_int_desc = cast_smem_ptr_to_uint(&smem_desc); uint64_t const new_desc_addr = reinterpret_cast(new_tensor_ptr); - uint64_t const smem_int64_desc = 0; - asm volatile ( - "cvt.u64.u32 %0, %1;" - :: "l"(smem_int64_desc), "r"(smem_int_desc)); asm volatile ( "tensormap.replace.tile.global_address.shared::cta.b1024.b64 [%0], %1;" - :: "l"(smem_int64_desc), "l"(new_desc_addr)); + :: "r"(smem_int_desc), "l"(new_desc_addr)); #else - CUTE_INVALID_CONTROL_PATH("Using TMA Descriptor modification without CUTE_ARCH_TMA_SM90_ENABLED and CUDA 12.3"); + CUTE_INVALID_CONTROL_PATH("Using TMA Descriptor modification without CUTE_ARCH_DEVICE_MODIFIABLE_TMA_SM90_ENABLED and CUDA 12.3"); #endif } @@ -331,7 +343,6 @@ tma_descriptor_replace_dims_strides_in_shared_mem(TmaDescriptor :: "l"(smem_int64_desc), "r"(prob_shape[2])); // Strides must be a multiple of 16. 
Also, stride for the intermost dimension is implicitly 1 #if ((__CUDACC_VER_MAJOR__ > 12) || ((__CUDACC_VER_MAJOR__ == 12) && (__CUDACC_VER_MINOR__ >= 5))) - // 4 LSBs are not included asm volatile ( "tensormap.replace.tile.global_stride.shared::cta.b1024.b64 [%0], 0, %1;" :: "l"(smem_int64_desc), "l"(prob_stride[1])); @@ -339,6 +350,7 @@ tma_descriptor_replace_dims_strides_in_shared_mem(TmaDescriptor "tensormap.replace.tile.global_stride.shared::cta.b1024.b64 [%0], 1, %1;" :: "l"(smem_int64_desc), "l"(prob_stride[2])); #else + // 4 LSBs are not included asm volatile ( "tensormap.replace.tile.global_stride.shared::cta.b1024.b64 [%0], 0, %1;" :: "l"(smem_int64_desc), "l"(prob_stride[1] >> 4)); @@ -347,7 +359,7 @@ tma_descriptor_replace_dims_strides_in_shared_mem(TmaDescriptor :: "l"(smem_int64_desc), "l"(prob_stride[2] >> 4)); #endif #else - CUTE_INVALID_CONTROL_PATH("Using TMA Descriptor modification without CUTE_ARCH_TMA_SM90_ENABLED and CUDA 12.3"); + CUTE_INVALID_CONTROL_PATH("Using TMA Descriptor modification without CUTE_ARCH_DEVICE_MODIFIABLE_TMA_SM90_ENABLED and CUDA 12.3"); #endif } @@ -366,7 +378,7 @@ tma_descriptor_cp_fence_release(TmaDescriptor const* gmem_desc_ptr, TmaDescripto "tensormap.cp_fenceproxy.global.shared::cta.tensormap::generic.release.gpu.sync.aligned [%0], [%1], 128;" :: "l"(gmem_int_desc), "r"(smem_int_desc)); #else - CUTE_INVALID_CONTROL_PATH("Using TMA Descriptor modification without CUTE_ARCH_TMA_SM90_ENABLED and CUDA 12.3"); + CUTE_INVALID_CONTROL_PATH("Using TMA Descriptor modification without CUTE_ARCH_DEVICE_MODIFIABLE_TMA_SM90_ENABLED and CUDA 12.3"); #endif } @@ -381,7 +393,7 @@ tma_descriptor_fence_release() #if defined(CUTE_ARCH_DEVICE_MODIFIABLE_TMA_SM90_ENABLED) asm volatile ("fence.proxy.tensormap::generic.release.gpu;"); #else - CUTE_INVALID_CONTROL_PATH("Using TMA Descriptor modification without CUTE_ARCH_TMA_SM90_ENABLED and CUDA 12.3"); + CUTE_INVALID_CONTROL_PATH("Using TMA Descriptor modification without CUTE_ARCH_DEVICE_MODIFIABLE_TMA_SM90_ENABLED and CUDA 12.3"); #endif } @@ -400,13 +412,8 @@ tma_descriptor_fence_acquire(TmaDescriptor const* desc_ptr) : : "l"(gmem_int_desc) : "memory"); - asm volatile ( - "cvta.global.u64 %0, %0;" - : - : "l"(gmem_int_desc), "l"(gmem_int_desc) - : "memory"); #else - CUTE_INVALID_CONTROL_PATH("Using TMA Descriptor modification without CUTE_ARCH_TMA_SM90_ENABLED and CUDA 12.3"); + CUTE_INVALID_CONTROL_PATH("Using TMA Descriptor modification without CUTE_ARCH_DEVICE_MODIFIABLE_TMA_SM90_ENABLED and CUDA 12.3"); #endif } diff --git a/include/cute/arch/copy_sm90_tma.hpp b/include/cute/arch/copy_sm90_tma.hpp index 1851482119..fb33d63cad 100644 --- a/include/cute/arch/copy_sm90_tma.hpp +++ b/include/cute/arch/copy_sm90_tma.hpp @@ -32,8 +32,11 @@ #include +#include // CUTE_ARCH_TMA_SMxx_ENABLED #include #include +#include "cutlass/arch/synclog.hpp" + namespace cute { @@ -52,6 +55,7 @@ struct SM90_TMA_LOAD_1D uint64_t gmem_int_desc = reinterpret_cast(desc_ptr); uint32_t smem_int_mbar = cast_smem_ptr_to_uint(mbar_ptr); uint32_t smem_int_ptr = cast_smem_ptr_to_uint(smem_ptr); + cutlass::arch::synclog_emit_tma_load(__LINE__, gmem_int_desc, smem_int_mbar, smem_int_ptr); asm volatile ( "cp.async.bulk.tensor.1d.shared::cluster.global.mbarrier::complete_tx::bytes.L2::cache_hint" " [%0], [%1, {%3}], [%2], %4;" @@ -97,6 +101,7 @@ struct SM90_TMA_LOAD_2D uint64_t gmem_int_desc = reinterpret_cast(desc_ptr); uint32_t smem_int_mbar = cast_smem_ptr_to_uint(mbar_ptr); uint32_t smem_int_ptr = cast_smem_ptr_to_uint(smem_ptr); + 
cutlass::arch::synclog_emit_tma_load(__LINE__, gmem_int_desc, smem_int_mbar, smem_int_ptr); asm volatile ( "cp.async.bulk.tensor.2d.shared::cluster.global.mbarrier::complete_tx::bytes.L2::cache_hint" " [%0], [%1, {%3, %4}], [%2], %5;" @@ -142,6 +147,7 @@ struct SM90_TMA_LOAD_3D uint64_t gmem_int_desc = reinterpret_cast(desc_ptr); uint32_t smem_int_mbar = cast_smem_ptr_to_uint(mbar_ptr); uint32_t smem_int_ptr = cast_smem_ptr_to_uint(smem_ptr); + cutlass::arch::synclog_emit_tma_load(__LINE__, gmem_int_desc, smem_int_mbar, smem_int_ptr); asm volatile ( "cp.async.bulk.tensor.3d.shared::cluster.global.mbarrier::complete_tx::bytes.L2::cache_hint" " [%0], [%1, {%3, %4, %5}], [%2], %6;" @@ -187,6 +193,7 @@ struct SM90_TMA_LOAD_4D uint64_t gmem_int_desc = reinterpret_cast(desc_ptr); uint32_t smem_int_mbar = cast_smem_ptr_to_uint(mbar_ptr); uint32_t smem_int_ptr = cast_smem_ptr_to_uint(smem_ptr); + cutlass::arch::synclog_emit_tma_load(__LINE__, gmem_int_desc, smem_int_mbar, smem_int_ptr); asm volatile ( "cp.async.bulk.tensor.4d.shared::cluster.global.mbarrier::complete_tx::bytes.L2::cache_hint" " [%0], [%1, {%3, %4, %5, %6}], [%2], %7;" @@ -232,6 +239,7 @@ struct SM90_TMA_LOAD_5D uint64_t gmem_int_desc = reinterpret_cast(desc_ptr); uint32_t smem_int_mbar = cast_smem_ptr_to_uint(mbar_ptr); uint32_t smem_int_ptr = cast_smem_ptr_to_uint(smem_ptr); + cutlass::arch::synclog_emit_tma_load(__LINE__, gmem_int_desc, smem_int_mbar, smem_int_ptr); asm volatile ( "cp.async.bulk.tensor.5d.shared::cluster.global.mbarrier::complete_tx::bytes.L2::cache_hint" " [%0], [%1, {%3, %4, %5, %6, %7}], [%2], %8;" @@ -355,6 +363,7 @@ struct SM90_TMA_LOAD_IM2COL_3D uint64_t gmem_int_desc = reinterpret_cast(desc_ptr); uint32_t smem_int_mbar = cast_smem_ptr_to_uint(mbar_ptr); uint32_t smem_int_ptr = cast_smem_ptr_to_uint(smem_ptr); + cutlass::arch::synclog_emit_tma_load(__LINE__, gmem_int_desc, smem_int_mbar, smem_int_ptr); // Copy from global to shared::cluster. asm volatile ( "cp.async.bulk.tensor.3d.shared::cluster.global.im2col.mbarrier::complete_tx::bytes" @@ -405,6 +414,7 @@ struct SM90_TMA_LOAD_IM2COL_4D uint64_t gmem_int_desc = reinterpret_cast(desc_ptr); uint32_t smem_int_mbar = cast_smem_ptr_to_uint(mbar_ptr); uint32_t smem_int_ptr = cast_smem_ptr_to_uint(smem_ptr); + cutlass::arch::synclog_emit_tma_load(__LINE__, gmem_int_desc, smem_int_mbar, smem_int_ptr); // Copy from global to shared::cluster. asm volatile ( "cp.async.bulk.tensor.4d.shared::cluster.global.im2col.mbarrier::complete_tx::bytes" @@ -455,6 +465,7 @@ struct SM90_TMA_LOAD_IM2COL_5D uint64_t gmem_int_desc = reinterpret_cast(desc_ptr); uint32_t smem_int_mbar = cast_smem_ptr_to_uint(mbar_ptr); uint32_t smem_int_ptr = cast_smem_ptr_to_uint(smem_ptr); + cutlass::arch::synclog_emit_tma_load(__LINE__, gmem_int_desc, smem_int_mbar, smem_int_ptr); // Copy from global to shared::cluster. 
asm volatile ( "cp.async.bulk.tensor.5d.shared::cluster.global.im2col.mbarrier::complete_tx::bytes" @@ -565,7 +576,7 @@ struct SM90_TMA_LOAD_IM2COL struct SM90_TMA_LOAD_MULTICAST_1D { CUTE_HOST_DEVICE static void - copy(void const* desc_ptr, uint64_t* mbar_ptr, uint16_t multicast_mask, + copy(void const* desc_ptr, uint64_t* mbar_ptr, uint16_t multicast_mask, uint64_t cache_hint, void * smem_ptr, int32_t const& crd0) { @@ -573,13 +584,14 @@ struct SM90_TMA_LOAD_MULTICAST_1D uint64_t gmem_int_desc = reinterpret_cast(desc_ptr); uint32_t smem_int_mbar = cast_smem_ptr_to_uint(mbar_ptr); uint32_t smem_int_ptr = cast_smem_ptr_to_uint(smem_ptr); + cutlass::arch::synclog_emit_tma_load(__LINE__, gmem_int_desc, smem_int_mbar, smem_int_ptr); asm volatile ( - "cp.async.bulk.tensor.1d.shared::cluster.global.mbarrier::complete_tx::bytes.multicast::cluster" - " [%0], [%1, {%4}], [%2], %3;" + "cp.async.bulk.tensor.1d.shared::cluster.global.mbarrier::complete_tx::bytes.multicast::cluster.L2::cache_hint" + " [%0], [%1, {%4}], [%2], %3, %5;" : : "r"(smem_int_ptr), "l"(gmem_int_desc), "r"(smem_int_mbar), "h"(multicast_mask), - "r"(crd0) + "r"(crd0), "l"(cache_hint) : "memory"); #else CUTE_INVALID_CONTROL_PATH("Trying to use tma without CUTE_ARCH_TMA_SM90_ENABLED."); @@ -590,7 +602,7 @@ struct SM90_TMA_LOAD_MULTICAST_1D struct SM90_TMA_LOAD_MULTICAST_2D { CUTE_HOST_DEVICE static void - copy(void const* desc_ptr, uint64_t* mbar_ptr, uint16_t multicast_mask, + copy(void const* desc_ptr, uint64_t* mbar_ptr, uint16_t multicast_mask, uint64_t cache_hint, void * smem_ptr, int32_t const& crd0, int32_t const& crd1) { @@ -598,13 +610,14 @@ struct SM90_TMA_LOAD_MULTICAST_2D uint64_t gmem_int_desc = reinterpret_cast(desc_ptr); uint32_t smem_int_mbar = cast_smem_ptr_to_uint(mbar_ptr); uint32_t smem_int_ptr = cast_smem_ptr_to_uint(smem_ptr); + cutlass::arch::synclog_emit_tma_load(__LINE__, gmem_int_desc, smem_int_mbar, smem_int_ptr); asm volatile ( - "cp.async.bulk.tensor.2d.shared::cluster.global.mbarrier::complete_tx::bytes.multicast::cluster" - " [%0], [%1, {%4, %5}], [%2], %3;" + "cp.async.bulk.tensor.2d.shared::cluster.global.mbarrier::complete_tx::bytes.multicast::cluster.L2::cache_hint" + " [%0], [%1, {%4, %5}], [%2], %3, %6;" : : "r"(smem_int_ptr), "l"(gmem_int_desc), "r"(smem_int_mbar), "h"(multicast_mask), - "r"(crd0), "r"(crd1) + "r"(crd0), "r"(crd1), "l"(cache_hint) : "memory"); #else CUTE_INVALID_CONTROL_PATH("Trying to use tma without CUTE_ARCH_TMA_SM90_ENABLED."); @@ -615,7 +628,7 @@ struct SM90_TMA_LOAD_MULTICAST_2D struct SM90_TMA_LOAD_MULTICAST_3D { CUTE_HOST_DEVICE static void - copy(void const* desc_ptr, uint64_t* mbar_ptr, uint16_t multicast_mask, + copy(void const* desc_ptr, uint64_t* mbar_ptr, uint16_t multicast_mask, uint64_t cache_hint, void * smem_ptr, int32_t const& crd0, int32_t const& crd1, int32_t const& crd2) { @@ -623,13 +636,14 @@ struct SM90_TMA_LOAD_MULTICAST_3D uint64_t gmem_int_desc = reinterpret_cast(desc_ptr); uint32_t smem_int_mbar = cast_smem_ptr_to_uint(mbar_ptr); uint32_t smem_int_ptr = cast_smem_ptr_to_uint(smem_ptr); + cutlass::arch::synclog_emit_tma_load(__LINE__, gmem_int_desc, smem_int_mbar, smem_int_ptr); asm volatile ( - "cp.async.bulk.tensor.3d.shared::cluster.global.mbarrier::complete_tx::bytes.multicast::cluster" - " [%0], [%1, {%4, %5, %6}], [%2], %3;" + "cp.async.bulk.tensor.3d.shared::cluster.global.mbarrier::complete_tx::bytes.multicast::cluster.L2::cache_hint" + " [%0], [%1, {%4, %5, %6}], [%2], %3, %7;" : : "r"(smem_int_ptr), "l"(gmem_int_desc), "r"(smem_int_mbar), 
"h"(multicast_mask), - "r"(crd0), "r"(crd1), "r"(crd2) + "r"(crd0), "r"(crd1), "r"(crd2), "l"(cache_hint) : "memory"); #else CUTE_INVALID_CONTROL_PATH("Trying to use tma without CUTE_ARCH_TMA_SM90_ENABLED."); @@ -640,7 +654,7 @@ struct SM90_TMA_LOAD_MULTICAST_3D struct SM90_TMA_LOAD_MULTICAST_4D { CUTE_HOST_DEVICE static void - copy(void const* desc_ptr, uint64_t* mbar_ptr, uint16_t multicast_mask, + copy(void const* desc_ptr, uint64_t* mbar_ptr, uint16_t multicast_mask, uint64_t cache_hint, void * smem_ptr, int32_t const& crd0, int32_t const& crd1, int32_t const& crd2, int32_t const& crd3) { @@ -648,13 +662,14 @@ struct SM90_TMA_LOAD_MULTICAST_4D uint64_t gmem_int_desc = reinterpret_cast(desc_ptr); uint32_t smem_int_mbar = cast_smem_ptr_to_uint(mbar_ptr); uint32_t smem_int_ptr = cast_smem_ptr_to_uint(smem_ptr); + cutlass::arch::synclog_emit_tma_load(__LINE__, gmem_int_desc, smem_int_mbar, smem_int_ptr); asm volatile ( - "cp.async.bulk.tensor.4d.shared::cluster.global.mbarrier::complete_tx::bytes.multicast::cluster" - " [%0], [%1, {%4, %5, %6, %7}], [%2], %3;" + "cp.async.bulk.tensor.4d.shared::cluster.global.mbarrier::complete_tx::bytes.multicast::cluster.L2::cache_hint" + " [%0], [%1, {%4, %5, %6, %7}], [%2], %3, %8;" : : "r"(smem_int_ptr), "l"(gmem_int_desc), "r"(smem_int_mbar), "h"(multicast_mask), - "r"(crd0), "r"(crd1), "r"(crd2), "r"(crd3) + "r"(crd0), "r"(crd1), "r"(crd2), "r"(crd3), "l"(cache_hint) : "memory"); #else CUTE_INVALID_CONTROL_PATH("Trying to use tma without CUTE_ARCH_TMA_SM90_ENABLED."); @@ -665,7 +680,7 @@ struct SM90_TMA_LOAD_MULTICAST_4D struct SM90_TMA_LOAD_MULTICAST_5D { CUTE_HOST_DEVICE static void - copy(void const* desc_ptr, uint64_t* mbar_ptr, uint16_t multicast_mask, + copy(void const* desc_ptr, uint64_t* mbar_ptr, uint16_t multicast_mask, uint64_t cache_hint, void * smem_ptr, int32_t const& crd0, int32_t const& crd1, int32_t const& crd2, int32_t const& crd3, int32_t const& crd4) { @@ -673,13 +688,14 @@ struct SM90_TMA_LOAD_MULTICAST_5D uint64_t gmem_int_desc = reinterpret_cast(desc_ptr); uint32_t smem_int_mbar = cast_smem_ptr_to_uint(mbar_ptr); uint32_t smem_int_ptr = cast_smem_ptr_to_uint(smem_ptr); + cutlass::arch::synclog_emit_tma_load(__LINE__, gmem_int_desc, smem_int_mbar, smem_int_ptr); asm volatile ( - "cp.async.bulk.tensor.5d.shared::cluster.global.mbarrier::complete_tx::bytes.multicast::cluster" - " [%0], [%1, {%4, %5, %6, %7, %8}], [%2], %3;" + "cp.async.bulk.tensor.5d.shared::cluster.global.mbarrier::complete_tx::bytes.multicast::cluster.L2::cache_hint" + " [%0], [%1, {%4, %5, %6, %7, %8}], [%2], %3, %9;" : : "r"(smem_int_ptr), "l"(gmem_int_desc), "r"(smem_int_mbar), "h"(multicast_mask), - "r"(crd0), "r"(crd1), "r"(crd2), "r"(crd3), "r"(crd4) + "r"(crd0), "r"(crd1), "r"(crd2), "r"(crd3), "r"(crd4), "l"(cache_hint) : "memory"); #else CUTE_INVALID_CONTROL_PATH("Trying to use tma without CUTE_ARCH_TMA_SM90_ENABLED."); @@ -690,39 +706,39 @@ struct SM90_TMA_LOAD_MULTICAST_5D struct SM90_TMA_LOAD_MULTICAST { CUTE_HOST_DEVICE static void - copy(void const* desc_ptr, uint64_t* mbar_ptr, uint16_t multicast_mask, + copy(void const* desc_ptr, uint64_t* mbar_ptr, uint16_t multicast_mask, uint64_t cache_hint, void * smem_ptr, int32_t const& crd0) { - return SM90_TMA_LOAD_MULTICAST_1D::copy(desc_ptr, mbar_ptr, multicast_mask, smem_ptr, crd0); + return SM90_TMA_LOAD_MULTICAST_1D::copy(desc_ptr, mbar_ptr, multicast_mask, cache_hint, smem_ptr, crd0); } CUTE_HOST_DEVICE static void - copy(void const* desc_ptr, uint64_t* mbar_ptr, uint16_t multicast_mask, + 
copy(void const* desc_ptr, uint64_t* mbar_ptr, uint16_t multicast_mask, uint64_t cache_hint, void * smem_ptr, int32_t const& crd0, int32_t const& crd1) { - return SM90_TMA_LOAD_MULTICAST_2D::copy(desc_ptr, mbar_ptr, multicast_mask, smem_ptr, crd0, crd1); + return SM90_TMA_LOAD_MULTICAST_2D::copy(desc_ptr, mbar_ptr, multicast_mask, cache_hint, smem_ptr, crd0, crd1); } CUTE_HOST_DEVICE static void - copy(void const* desc_ptr, uint64_t* mbar_ptr, uint16_t multicast_mask, + copy(void const* desc_ptr, uint64_t* mbar_ptr, uint16_t multicast_mask, uint64_t cache_hint, void * smem_ptr, int32_t const& crd0, int32_t const& crd1, int32_t const& crd2) { - return SM90_TMA_LOAD_MULTICAST_3D::copy(desc_ptr, mbar_ptr, multicast_mask, smem_ptr, crd0, crd1, crd2); + return SM90_TMA_LOAD_MULTICAST_3D::copy(desc_ptr, mbar_ptr, multicast_mask, cache_hint, smem_ptr, crd0, crd1, crd2); } CUTE_HOST_DEVICE static void - copy(void const* desc_ptr, uint64_t* mbar_ptr, uint16_t multicast_mask, + copy(void const* desc_ptr, uint64_t* mbar_ptr, uint16_t multicast_mask, uint64_t cache_hint, void * smem_ptr, int32_t const& crd0, int32_t const& crd1, int32_t const& crd2, int32_t const& crd3) { - return SM90_TMA_LOAD_MULTICAST_4D::copy(desc_ptr, mbar_ptr, multicast_mask, smem_ptr, crd0, crd1, crd2, crd3); + return SM90_TMA_LOAD_MULTICAST_4D::copy(desc_ptr, mbar_ptr, multicast_mask, cache_hint, smem_ptr, crd0, crd1, crd2, crd3); } CUTE_HOST_DEVICE static void - copy(void const* desc_ptr, uint64_t* mbar_ptr, uint16_t multicast_mask, + copy(void const* desc_ptr, uint64_t* mbar_ptr, uint16_t multicast_mask, uint64_t cache_hint, void * smem_ptr, int32_t const& crd0, int32_t const& crd1, int32_t const& crd2, int32_t const& crd3, int32_t const& crd4) { - return SM90_TMA_LOAD_MULTICAST_5D::copy(desc_ptr, mbar_ptr, multicast_mask, smem_ptr, crd0, crd1, crd2, crd3, crd4); + return SM90_TMA_LOAD_MULTICAST_5D::copy(desc_ptr, mbar_ptr, multicast_mask, cache_hint, smem_ptr, crd0, crd1, crd2, crd3, crd4); } using PREFETCH = typename SM90_TMA_LOAD::PREFETCH; @@ -744,6 +760,7 @@ struct SM90_TMA_LOAD_IM2COL_MULTICAST_3D uint64_t gmem_int_desc = reinterpret_cast(desc_ptr); uint32_t smem_int_mbar = cast_smem_ptr_to_uint(mbar_ptr); uint32_t smem_int_ptr = cast_smem_ptr_to_uint(smem_ptr); + cutlass::arch::synclog_emit_tma_load(__LINE__, gmem_int_desc, smem_int_mbar, smem_int_ptr); // Copy from global to shared::cluster. asm volatile ( "cp.async.bulk.tensor.3d.shared::cluster.global.im2col.mbarrier::complete_tx::bytes.multicast::cluster" @@ -772,6 +789,7 @@ struct SM90_TMA_LOAD_IM2COL_MULTICAST_4D uint64_t gmem_int_desc = reinterpret_cast(desc_ptr); uint32_t smem_int_mbar = cast_smem_ptr_to_uint(mbar_ptr); uint32_t smem_int_ptr = cast_smem_ptr_to_uint(smem_ptr); + cutlass::arch::synclog_emit_tma_load(__LINE__, gmem_int_desc, smem_int_mbar, smem_int_ptr); // Copy from global to shared::cluster. asm volatile ( "cp.async.bulk.tensor.4d.shared::cluster.global.im2col.mbarrier::complete_tx::bytes.multicast::cluster" @@ -800,6 +818,7 @@ struct SM90_TMA_LOAD_IM2COL_MULTICAST_5D uint64_t gmem_int_desc = reinterpret_cast(desc_ptr); uint32_t smem_int_mbar = cast_smem_ptr_to_uint(mbar_ptr); uint32_t smem_int_ptr = cast_smem_ptr_to_uint(smem_ptr); + cutlass::arch::synclog_emit_tma_load(__LINE__, gmem_int_desc, smem_int_mbar, smem_int_ptr); // Copy from global to shared::cluster. 
asm volatile ( "cp.async.bulk.tensor.5d.shared::cluster.global.im2col.mbarrier::complete_tx::bytes.multicast::cluster" @@ -871,6 +890,7 @@ struct SM90_TMA_STORE_1D #if defined(CUTE_ARCH_TMA_SM90_ENABLED) uint64_t gmem_int_desc = reinterpret_cast(desc_ptr); uint32_t smem_int_ptr = cast_smem_ptr_to_uint(smem_ptr); + cutlass::arch::synclog_emit_tma_store(__LINE__, gmem_int_desc, smem_int_ptr); asm volatile ( "cp.async.bulk.tensor.1d.global.shared::cta.bulk_group [%0, {%2}], [%1];" : @@ -893,6 +913,7 @@ struct SM90_TMA_STORE_2D #if defined(CUTE_ARCH_TMA_SM90_ENABLED) uint64_t gmem_int_desc = reinterpret_cast(desc_ptr); uint32_t smem_int_ptr = cast_smem_ptr_to_uint(smem_ptr); + cutlass::arch::synclog_emit_tma_store(__LINE__, gmem_int_desc, smem_int_ptr); asm volatile ( "cp.async.bulk.tensor.2d.global.shared::cta.bulk_group [%0, {%2, %3}], [%1];" : @@ -915,6 +936,7 @@ struct SM90_TMA_STORE_3D #if defined(CUTE_ARCH_TMA_SM90_ENABLED) uint64_t gmem_int_desc = reinterpret_cast(desc_ptr); uint32_t smem_int_ptr = cast_smem_ptr_to_uint(smem_ptr); + cutlass::arch::synclog_emit_tma_store(__LINE__, gmem_int_desc, smem_int_ptr); asm volatile ( "cp.async.bulk.tensor.3d.global.shared::cta.bulk_group [%0, {%2, %3, %4}], [%1];" : @@ -937,6 +959,7 @@ struct SM90_TMA_STORE_4D #if defined(CUTE_ARCH_TMA_SM90_ENABLED) uint64_t gmem_int_desc = reinterpret_cast(desc_ptr); uint32_t smem_int_ptr = cast_smem_ptr_to_uint(smem_ptr); + cutlass::arch::synclog_emit_tma_store(__LINE__, gmem_int_desc, smem_int_ptr); asm volatile ( "cp.async.bulk.tensor.4d.global.shared::cta.bulk_group [%0, {%2, %3, %4, %5}], [%1];" : @@ -959,6 +982,7 @@ struct SM90_TMA_STORE_5D #if defined(CUTE_ARCH_TMA_SM90_ENABLED) uint64_t gmem_int_desc = reinterpret_cast(desc_ptr); uint32_t smem_int_ptr = cast_smem_ptr_to_uint(smem_ptr); + cutlass::arch::synclog_emit_tma_store(__LINE__, gmem_int_desc, smem_int_ptr); asm volatile ( "cp.async.bulk.tensor.5d.global.shared::cta.bulk_group [%0, {%2, %3, %4, %5, %6}], [%1];" : @@ -1024,6 +1048,7 @@ struct SM90_TMA_STORE_IM2COL_3D #if defined(CUTE_ARCH_TMA_SM90_ENABLED) uint64_t gmem_int_desc = reinterpret_cast(desc_ptr); uint32_t smem_int_ptr = cast_smem_ptr_to_uint(smem_ptr); + cutlass::arch::synclog_emit_tma_store(__LINE__, gmem_int_desc, smem_int_ptr); asm volatile ( "cp.async.bulk.tensor.3d.global.shared::cta.im2col_no_offs.bulk_group" " [%0, {%2, %3, %4}], [%1];" @@ -1047,6 +1072,7 @@ struct SM90_TMA_STORE_IM2COL_4D #if defined(CUTE_ARCH_TMA_SM90_ENABLED) uint64_t gmem_int_desc = reinterpret_cast(desc_ptr); uint32_t smem_int_ptr = cast_smem_ptr_to_uint(smem_ptr); + cutlass::arch::synclog_emit_tma_store(__LINE__, gmem_int_desc, smem_int_ptr); asm volatile ( "cp.async.bulk.tensor.4d.global.shared::cta.im2col_no_offs.bulk_group" " [%0, {%2, %3, %4, %5}], [%1];" @@ -1070,6 +1096,7 @@ struct SM90_TMA_STORE_IM2COL_5D #if defined(CUTE_ARCH_TMA_SM90_ENABLED) uint64_t gmem_int_desc = reinterpret_cast(desc_ptr); uint32_t smem_int_ptr = cast_smem_ptr_to_uint(smem_ptr); + cutlass::arch::synclog_emit_tma_store(__LINE__, gmem_int_desc, smem_int_ptr); asm volatile ( "cp.async.bulk.tensor.5d.global.shared::cta.im2col_no_offs.bulk_group" " [%0, {%2, %3, %4, %5, %6}], [%1];" @@ -1112,6 +1139,7 @@ struct SM90_TMA_STORE_IM2COL CUTE_HOST_DEVICE static void tma_store_fence() { #if defined(CUTE_ARCH_TMA_SM90_ENABLED) + cutlass::arch::synclog_emit_fence_view_async_shared(__LINE__); asm volatile ("fence.proxy.async.shared::cta;"); #elif defined(__CUDA_ARCH__) CUTE_INVALID_CONTROL_PATH("Trying to use tma without 
CUTE_ARCH_TMA_SM90_ENABLED."); @@ -1122,6 +1150,7 @@ tma_store_fence() { CUTE_HOST_DEVICE static void tma_store_arrive() { #if defined(CUTE_ARCH_TMA_SM90_ENABLED) + cutlass::arch::synclog_emit_tma_store_arrive(__LINE__); asm volatile("cp.async.bulk.commit_group;"); #else CUTE_INVALID_CONTROL_PATH("Trying to use tma without CUTE_ARCH_TMA_SM90_ENABLED."); @@ -1138,6 +1167,7 @@ tma_store_wait() { : : "n"(Count) : "memory"); + cutlass::arch::synclog_emit_tma_store_wait(__LINE__, Count); #else CUTE_INVALID_CONTROL_PATH("Trying to use tma without CUTE_ARCH_TMA_SM90_ENABLED."); #endif @@ -1157,6 +1187,7 @@ struct SM90_TMA_REDUCE_ADD_1D #if defined(CUTE_ARCH_TMA_SM90_ENABLED) uint64_t gmem_int_desc = reinterpret_cast(desc_ptr); uint32_t smem_int_ptr = cast_smem_ptr_to_uint(smem_ptr); + cutlass::arch::synclog_emit_tma_store(__LINE__, gmem_int_desc, smem_int_ptr); asm volatile ( "cp.reduce.async.bulk.tensor.1d.global.shared::cta.add.bulk_group [%0, {%2}], [%1];" : @@ -1179,6 +1210,7 @@ struct SM90_TMA_REDUCE_ADD_2D #if defined(CUTE_ARCH_TMA_SM90_ENABLED) uint64_t gmem_int_desc = reinterpret_cast(desc_ptr); uint32_t smem_int_ptr = cast_smem_ptr_to_uint(smem_ptr); + cutlass::arch::synclog_emit_tma_store(__LINE__, gmem_int_desc, smem_int_ptr); asm volatile ( "cp.reduce.async.bulk.tensor.2d.global.shared::cta.add.bulk_group [%0, {%2, %3}], [%1];" : @@ -1201,6 +1233,7 @@ struct SM90_TMA_REDUCE_ADD_3D #if defined(CUTE_ARCH_TMA_SM90_ENABLED) uint64_t gmem_int_desc = reinterpret_cast(desc_ptr); uint32_t smem_int_ptr = cast_smem_ptr_to_uint(smem_ptr); + cutlass::arch::synclog_emit_tma_store(__LINE__, gmem_int_desc, smem_int_ptr); asm volatile ( "cp.reduce.async.bulk.tensor.3d.global.shared::cta.add.bulk_group [%0, {%2, %3, %4}], [%1];" : @@ -1223,6 +1256,7 @@ struct SM90_TMA_REDUCE_ADD_4D #if defined(CUTE_ARCH_TMA_SM90_ENABLED) uint64_t gmem_int_desc = reinterpret_cast(desc_ptr); uint32_t smem_int_ptr = cast_smem_ptr_to_uint(smem_ptr); + cutlass::arch::synclog_emit_tma_store(__LINE__, gmem_int_desc, smem_int_ptr); asm volatile ( "cp.reduce.async.bulk.tensor.4d.global.shared::cta.add.bulk_group [%0, {%2, %3, %4, %5}], [%1];" : @@ -1245,6 +1279,7 @@ struct SM90_TMA_REDUCE_ADD_5D #if defined(CUTE_ARCH_TMA_SM90_ENABLED) uint64_t gmem_int_desc = reinterpret_cast(desc_ptr); uint32_t smem_int_ptr = cast_smem_ptr_to_uint(smem_ptr); + cutlass::arch::synclog_emit_tma_store(__LINE__, gmem_int_desc, smem_int_ptr); asm volatile ( "cp.reduce.async.bulk.tensor.5d.global.shared::cta.add.bulk_group [%0, {%2, %3, %4, %5, %6}], [%1];" : diff --git a/include/cute/arch/mma.hpp b/include/cute/arch/mma.hpp index 5bfda7463c..6e06114a6c 100644 --- a/include/cute/arch/mma.hpp +++ b/include/cute/arch/mma.hpp @@ -30,9 +30,9 @@ **************************************************************************************************/ #pragma once -#include - -#include +#include // CUTE_HOST_DEVICE +#include // cute::fma +#include // cute::fma namespace cute { diff --git a/include/cute/arch/mma_sm90.hpp b/include/cute/arch/mma_sm90.hpp index d504bf39df..6ab29adc9d 100644 --- a/include/cute/arch/mma_sm90.hpp +++ b/include/cute/arch/mma_sm90.hpp @@ -32,7 +32,6 @@ #pragma once #include - #include // Config @@ -45,10 +44,12 @@ namespace cute { +namespace SM90 { + //////////////////////////////////////////////////////////////////////////////////////////////////// // MMA 16x8x4 TN -struct SM90_16x8x4_F64F64F64F64_TN +struct MMA_16x8x4_F64F64F64F64_TN { using DRegisters = double[4]; using ARegisters = double[2]; @@ -73,7 +74,7 @@ struct 
SM90_16x8x4_F64F64F64F64_TN "d"(b0), "d"(c0), "d"(c1), "d"(c2), "d"(c3)); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_16x8x4_F64F64F64F64_TN without CUTE_ARCH_MMA_SM90_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_16x8x4_F64F64F64F64_TN without CUTE_ARCH_MMA_SM90_ENABLED"); #endif } }; @@ -81,7 +82,7 @@ struct SM90_16x8x4_F64F64F64F64_TN //////////////////////////////////////////////////////////////////////////////////////////////////// // MMA 16x8x8 TN -struct SM90_16x8x8_F64F64F64F64_TN +struct MMA_16x8x8_F64F64F64F64_TN { using DRegisters = double[4]; using ARegisters = double[4]; @@ -106,7 +107,7 @@ struct SM90_16x8x8_F64F64F64F64_TN "d"(b0), "d"(b1), "d"(c0), "d"(c1), "d"(c2), "d"(c3)); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_16x8x8_F64F64F64F64_TN without CUTE_ARCH_MMA_SM90_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_16x8x8_F64F64F64F64_TN without CUTE_ARCH_MMA_SM90_ENABLED"); #endif } }; @@ -114,7 +115,7 @@ struct SM90_16x8x8_F64F64F64F64_TN //////////////////////////////////////////////////////////////////////////////////////////////////// // MMA 16x8x16 TN -struct SM90_16x8x16_F64F64F64F64_TN +struct MMA_16x8x16_F64F64F64F64_TN { using DRegisters = double[4]; using ARegisters = double[8]; @@ -141,7 +142,7 @@ struct SM90_16x8x16_F64F64F64F64_TN "d"(b0), "d"(b1), "d"(b2), "d"(b3), "d"(c0), "d"(c1), "d"(c2), "d"(c3)); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_16x8x16_F64F64F64F64_TN without CUTE_ARCH_MMA_SM90_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_16x8x16_F64F64F64F64_TN without CUTE_ARCH_MMA_SM90_ENABLED"); #endif } }; @@ -149,7 +150,7 @@ struct SM90_16x8x16_F64F64F64F64_TN //////////////////////////////////////////////////////////////////////////////////////////////////// // MMA 16x8x4 TN -struct SM90_16x8x4_C64C64C64C64_TN +struct MMA_16x8x4_C64C64C64C64_TN { using DRegisters = complex[4]; using ARegisters = complex[2]; @@ -175,28 +176,28 @@ struct SM90_16x8x4_C64C64C64C64_TN double& id3 = reinterpret_cast(d3)[1]; // d.real() = a.real() * b.real() + c.real(); - SM90_16x8x4_F64F64F64F64_TN::fma( + MMA_16x8x4_F64F64F64F64_TN::fma( rd0, rd1, rd2, rd3, a0.real(), a1.real(), b0.real(), c0.real(), c1.real(), c2.real(), c3.real()); // d.imag() = a.imag() * b.real() + c.imag(); - SM90_16x8x4_F64F64F64F64_TN::fma( + MMA_16x8x4_F64F64F64F64_TN::fma( id0, id1, id2, id3, a0.imag(), a1.imag(), b0.real(), c0.imag(), c1.imag(), c2.imag(), c3.imag()); // d.real() = -a.imag() * b.imag() + d.real(); - SM90_16x8x4_F64F64F64F64_TN::fma( + MMA_16x8x4_F64F64F64F64_TN::fma( rd0, rd1, rd2, rd3, -a0.imag(), -a1.imag(), b0.imag(), d0.real(), d1.real(), d2.real(), d3.real()); // d.imag() = a.real() * b.imag() + d.imag(); - SM90_16x8x4_F64F64F64F64_TN::fma( + MMA_16x8x4_F64F64F64F64_TN::fma( id0, id1, id2, id3, a0.real(), a1.real(), b0.imag(), @@ -207,7 +208,7 @@ struct SM90_16x8x4_C64C64C64C64_TN //////////////////////////////////////////////////////////////////////////////////////////////////// // MMA 16x8x8 TN -struct SM90_16x8x8_C64C64C64C64_TN +struct MMA_16x8x8_C64C64C64C64_TN { using DRegisters = complex[4]; using ARegisters = complex[4]; @@ -234,28 +235,28 @@ struct SM90_16x8x8_C64C64C64C64_TN double& id3 = reinterpret_cast(d3)[1]; // d.real() = a.real() * b.real() + c.real(); - SM90_16x8x8_F64F64F64F64_TN::fma( + MMA_16x8x8_F64F64F64F64_TN::fma( rd0, rd1, rd2, rd3, a0.real(), a1.real(), a2.real(), a3.real(), b0.real(), b1.real(), c0.real(), c1.real(), c2.real(), c3.real()); // d.imag() = a.imag() 
* b.real() + c.imag(); - SM90_16x8x8_F64F64F64F64_TN::fma( + MMA_16x8x8_F64F64F64F64_TN::fma( id0, id1, id2, id3, a0.imag(), a1.imag(), a2.imag(), a3.imag(), b0.real(), b1.real(), c0.imag(), c1.imag(), c2.imag(), c3.imag()); // d.real() = -a.imag() * b.imag() + d.real(); - SM90_16x8x8_F64F64F64F64_TN::fma( + MMA_16x8x8_F64F64F64F64_TN::fma( rd0, rd1, rd2, rd3, -a0.imag(), -a1.imag(), -a2.imag(), -a3.imag(), b0.imag(), b1.imag(), d0.real(), d1.real(), d2.real(), d3.real()); // d.imag() = a.real() * b.imag() + d.imag(); - SM90_16x8x8_F64F64F64F64_TN::fma( + MMA_16x8x8_F64F64F64F64_TN::fma( id0, id1, id2, id3, a0.real(), a1.real(), a2.real(), a3.real(), b0.imag(), b1.imag(), @@ -266,7 +267,7 @@ struct SM90_16x8x8_C64C64C64C64_TN //////////////////////////////////////////////////////////////////////////////////////////////////// // MMA 16x8x16 TN -struct SM90_16x8x16_C64C64C64C64_TN +struct MMA_16x8x16_C64C64C64C64_TN { using DRegisters = complex[4]; using ARegisters = complex[8]; @@ -296,7 +297,7 @@ struct SM90_16x8x16_C64C64C64C64_TN double& id3 = reinterpret_cast(d3)[1]; // d.real() = a.real() * b.real() + c.real(); - SM90_16x8x16_F64F64F64F64_TN::fma( + MMA_16x8x16_F64F64F64F64_TN::fma( rd0, rd1, rd2, rd3, a0.real(), a1.real(), a2.real(), a3.real(), a4.real(), a5.real(), a6.real(), a7.real(), @@ -304,7 +305,7 @@ struct SM90_16x8x16_C64C64C64C64_TN c0.real(), c1.real(), c2.real(), c3.real()); // d.imag() = a.imag() * b.real() + c.imag(); - SM90_16x8x16_F64F64F64F64_TN::fma( + MMA_16x8x16_F64F64F64F64_TN::fma( id0, id1, id2, id3, a0.imag(), a1.imag(), a2.imag(), a3.imag(), a4.imag(), a5.imag(), a6.imag(), a7.imag(), @@ -312,7 +313,7 @@ struct SM90_16x8x16_C64C64C64C64_TN c0.imag(), c1.imag(), c2.imag(), c3.imag()); // d.real() = -a.imag() * b.imag() + d.real(); - SM90_16x8x16_F64F64F64F64_TN::fma( + MMA_16x8x16_F64F64F64F64_TN::fma( rd0, rd1, rd2, rd3, -a0.imag(), -a1.imag(), -a2.imag(), -a3.imag(), -a4.imag(), -a5.imag(), -a6.imag(), -a7.imag(), @@ -320,7 +321,7 @@ struct SM90_16x8x16_C64C64C64C64_TN d0.real(), d1.real(), d2.real(), d3.real()); // d.imag() = a.real() * b.imag() + d.imag(); - SM90_16x8x16_F64F64F64F64_TN::fma( + MMA_16x8x16_F64F64F64F64_TN::fma( id0, id1, id2, id3, a0.real(), a1.real(), a2.real(), a3.real(), a4.real(), a5.real(), a6.real(), a7.real(), @@ -331,17 +332,24 @@ struct SM90_16x8x16_C64C64C64C64_TN //////////////////////////////////////////////////////////////////////////////////////////////////// +} + } // namespace cute //////////////////////////////////////////////////////////////////////////////////////////////////// #include #include +#include +#include // cute::size +#include // cute::is_static +#include // cute::half_t, cute::float_e4m3_t, cute::tfloat32_t, etc +#include // cute::is_same_v //////////////////////////////////////////////////////////////////////////////////////////////////// namespace cute { -namespace GMMA { +namespace SM90::GMMA { template < class ElementA, @@ -370,73 +378,73 @@ ss_op_selector() static_assert(size<2>(TileShape_MNK{}) % 16 == 0, "Tile_K must be a multiple of 16."); if constexpr (Tile_N % 256 == 0) { - return SM90_64x256x16_F16F16F16_SS{}; + return SM90::GMMA::MMA_64x256x16_F16F16F16_SS{}; } #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 240 == 0) { - return SM90_64x240x16_F16F16F16_SS{}; + return SM90::GMMA::MMA_64x240x16_F16F16F16_SS{}; } #endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 224 == 0) { - return SM90_64x224x16_F16F16F16_SS{}; + return 
SM90::GMMA::MMA_64x224x16_F16F16F16_SS{}; } #endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 208 == 0) { - return SM90_64x208x16_F16F16F16_SS{}; + return SM90::GMMA::MMA_64x208x16_F16F16F16_SS{}; } #endif else if constexpr (Tile_N % 192 == 0) { - return SM90_64x192x16_F16F16F16_SS{}; + return SM90::GMMA::MMA_64x192x16_F16F16F16_SS{}; } #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 176 == 0) { - return SM90_64x176x16_F16F16F16_SS{}; + return SM90::GMMA::MMA_64x176x16_F16F16F16_SS{}; } #endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 160 == 0) { - return SM90_64x160x16_F16F16F16_SS{}; + return SM90::GMMA::MMA_64x160x16_F16F16F16_SS{}; } #endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 144 == 0) { - return SM90_64x144x16_F16F16F16_SS{}; + return SM90::GMMA::MMA_64x144x16_F16F16F16_SS{}; } #endif else if constexpr (Tile_N % 128 == 0) { - return SM90_64x128x16_F16F16F16_SS{}; + return SM90::GMMA::MMA_64x128x16_F16F16F16_SS{}; } #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 112 == 0) { - return SM90_64x112x16_F16F16F16_SS{}; + return SM90::GMMA::MMA_64x112x16_F16F16F16_SS{}; } #endif else if constexpr (Tile_N % 96 == 0) { - return SM90_64x96x16_F16F16F16_SS{}; + return SM90::GMMA::MMA_64x96x16_F16F16F16_SS{}; } #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 80 == 0) { - return SM90_64x80x16_F16F16F16_SS{}; + return SM90::GMMA::MMA_64x80x16_F16F16F16_SS{}; } #endif else if constexpr (Tile_N % 64 == 0) { - return SM90_64x64x16_F16F16F16_SS{}; + return SM90::GMMA::MMA_64x64x16_F16F16F16_SS{}; } #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 48 == 0) { - return SM90_64x48x16_F16F16F16_SS{}; + return SM90::GMMA::MMA_64x48x16_F16F16F16_SS{}; } #endif else if constexpr (Tile_N % 32 == 0) { - return SM90_64x32x16_F16F16F16_SS{}; + return SM90::GMMA::MMA_64x32x16_F16F16F16_SS{}; } else if constexpr (Tile_N % 16 == 0) { - return SM90_64x16x16_F16F16F16_SS{}; + return SM90::GMMA::MMA_64x16x16_F16F16F16_SS{}; } else if constexpr (Tile_N % 8 == 0) { - return SM90_64x8x16_F16F16F16_SS{}; + return SM90::GMMA::MMA_64x8x16_F16F16F16_SS{}; } else { static_assert(Tile_N % 8 == 0, "Tile_N must be a multiple of 8."); @@ -450,73 +458,73 @@ ss_op_selector() static_assert(size<2>(TileShape_MNK{}) % 32 == 0, "Tile_K must be a multiple of 32."); if constexpr (Tile_N % 256 == 0) { - return SM90_64x256x32_F16E4M3E4M3_SS_TN{}; + return SM90::GMMA::MMA_64x256x32_F16E4M3E4M3_SS_TN{}; } #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 240 == 0) { - return SM90_64x240x32_F16E4M3E4M3_SS_TN{}; + return SM90::GMMA::MMA_64x240x32_F16E4M3E4M3_SS_TN{}; } #endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 224 == 0) { - return SM90_64x224x32_F16E4M3E4M3_SS_TN{}; + return SM90::GMMA::MMA_64x224x32_F16E4M3E4M3_SS_TN{}; } #endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 208 == 0) { - return SM90_64x208x32_F16E4M3E4M3_SS_TN{}; + return SM90::GMMA::MMA_64x208x32_F16E4M3E4M3_SS_TN{}; } #endif else if constexpr (Tile_N % 192 == 0) { - return SM90_64x192x32_F16E4M3E4M3_SS_TN{}; + return SM90::GMMA::MMA_64x192x32_F16E4M3E4M3_SS_TN{}; } #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 176 == 0) { - return SM90_64x176x32_F16E4M3E4M3_SS_TN{}; + return 
SM90::GMMA::MMA_64x176x32_F16E4M3E4M3_SS_TN{}; } #endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 160 == 0) { - return SM90_64x160x32_F16E4M3E4M3_SS_TN{}; + return SM90::GMMA::MMA_64x160x32_F16E4M3E4M3_SS_TN{}; } #endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 144 == 0) { - return SM90_64x144x32_F16E4M3E4M3_SS_TN{}; + return SM90::GMMA::MMA_64x144x32_F16E4M3E4M3_SS_TN{}; } #endif else if constexpr (Tile_N % 128 == 0) { - return SM90_64x128x32_F16E4M3E4M3_SS_TN{}; + return SM90::GMMA::MMA_64x128x32_F16E4M3E4M3_SS_TN{}; } #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 112 == 0) { - return SM90_64x112x32_F16E4M3E4M3_SS_TN{}; + return SM90::GMMA::MMA_64x112x32_F16E4M3E4M3_SS_TN{}; } #endif else if constexpr (Tile_N % 96 == 0) { - return SM90_64x96x32_F16E4M3E4M3_SS_TN{}; + return SM90::GMMA::MMA_64x96x32_F16E4M3E4M3_SS_TN{}; } #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 80 == 0) { - return SM90_64x80x32_F16E4M3E4M3_SS_TN{}; + return SM90::GMMA::MMA_64x80x32_F16E4M3E4M3_SS_TN{}; } #endif else if constexpr (Tile_N % 64 == 0) { - return SM90_64x64x32_F16E4M3E4M3_SS_TN{}; + return SM90::GMMA::MMA_64x64x32_F16E4M3E4M3_SS_TN{}; } #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 48 == 0) { - return SM90_64x48x32_F16E4M3E4M3_SS_TN{}; + return SM90::GMMA::MMA_64x48x32_F16E4M3E4M3_SS_TN{}; } #endif else if constexpr (Tile_N % 32 == 0) { - return SM90_64x32x32_F16E4M3E4M3_SS_TN{}; + return SM90::GMMA::MMA_64x32x32_F16E4M3E4M3_SS_TN{}; } else if constexpr (Tile_N % 16 == 0) { - return SM90_64x16x32_F16E4M3E4M3_SS_TN{}; + return SM90::GMMA::MMA_64x16x32_F16E4M3E4M3_SS_TN{}; } else if constexpr (Tile_N % 8 == 0) { - return SM90_64x8x32_F16E4M3E4M3_SS_TN{}; + return SM90::GMMA::MMA_64x8x32_F16E4M3E4M3_SS_TN{}; } else { static_assert(Tile_N % 8 == 0, "Tile_N must be a multiple of 8."); @@ -530,73 +538,73 @@ ss_op_selector() static_assert(size<2>(TileShape_MNK{}) % 32 == 0, "Tile_K must be a multiple of 32."); if constexpr (Tile_N % 256 == 0) { - return SM90_64x256x32_F16E4M3E5M2_SS_TN{}; + return SM90::GMMA::MMA_64x256x32_F16E4M3E5M2_SS_TN{}; } #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 240 == 0) { - return SM90_64x240x32_F16E4M3E5M2_SS_TN{}; + return SM90::GMMA::MMA_64x240x32_F16E4M3E5M2_SS_TN{}; } #endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 224 == 0) { - return SM90_64x224x32_F16E4M3E5M2_SS_TN{}; + return SM90::GMMA::MMA_64x224x32_F16E4M3E5M2_SS_TN{}; } #endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 208 == 0) { - return SM90_64x208x32_F16E4M3E5M2_SS_TN{}; + return SM90::GMMA::MMA_64x208x32_F16E4M3E5M2_SS_TN{}; } #endif else if constexpr (Tile_N % 192 == 0) { - return SM90_64x192x32_F16E4M3E5M2_SS_TN{}; + return SM90::GMMA::MMA_64x192x32_F16E4M3E5M2_SS_TN{}; } #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 176 == 0) { - return SM90_64x176x32_F16E4M3E5M2_SS_TN{}; + return SM90::GMMA::MMA_64x176x32_F16E4M3E5M2_SS_TN{}; } #endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 160 == 0) { - return SM90_64x160x32_F16E4M3E5M2_SS_TN{}; + return SM90::GMMA::MMA_64x160x32_F16E4M3E5M2_SS_TN{}; } #endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 144 == 0) { - return SM90_64x144x32_F16E4M3E5M2_SS_TN{}; + return 
SM90::GMMA::MMA_64x144x32_F16E4M3E5M2_SS_TN{}; } #endif else if constexpr (Tile_N % 128 == 0) { - return SM90_64x128x32_F16E4M3E5M2_SS_TN{}; + return SM90::GMMA::MMA_64x128x32_F16E4M3E5M2_SS_TN{}; } #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 112 == 0) { - return SM90_64x112x32_F16E4M3E5M2_SS_TN{}; + return SM90::GMMA::MMA_64x112x32_F16E4M3E5M2_SS_TN{}; } #endif else if constexpr (Tile_N % 96 == 0) { - return SM90_64x96x32_F16E4M3E5M2_SS_TN{}; + return SM90::GMMA::MMA_64x96x32_F16E4M3E5M2_SS_TN{}; } #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 80 == 0) { - return SM90_64x80x32_F16E4M3E5M2_SS_TN{}; + return SM90::GMMA::MMA_64x80x32_F16E4M3E5M2_SS_TN{}; } #endif else if constexpr (Tile_N % 64 == 0) { - return SM90_64x64x32_F16E4M3E5M2_SS_TN{}; + return SM90::GMMA::MMA_64x64x32_F16E4M3E5M2_SS_TN{}; } #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 48 == 0) { - return SM90_64x48x32_F16E4M3E5M2_SS_TN{}; + return SM90::GMMA::MMA_64x48x32_F16E4M3E5M2_SS_TN{}; } #endif else if constexpr (Tile_N % 32 == 0) { - return SM90_64x32x32_F16E4M3E5M2_SS_TN{}; + return SM90::GMMA::MMA_64x32x32_F16E4M3E5M2_SS_TN{}; } else if constexpr (Tile_N % 16 == 0) { - return SM90_64x16x32_F16E4M3E5M2_SS_TN{}; + return SM90::GMMA::MMA_64x16x32_F16E4M3E5M2_SS_TN{}; } else if constexpr (Tile_N % 8 == 0) { - return SM90_64x8x32_F16E4M3E5M2_SS_TN{}; + return SM90::GMMA::MMA_64x8x32_F16E4M3E5M2_SS_TN{}; } else { static_assert(Tile_N % 8 == 0, "Tile_N must be a multiple of 8."); @@ -610,73 +618,73 @@ ss_op_selector() static_assert(size<2>(TileShape_MNK{}) % 32 == 0, "Tile_K must be a multiple of 32."); if constexpr (Tile_N % 256 == 0) { - return SM90_64x256x32_F16E5M2E4M3_SS_TN{}; + return SM90::GMMA::MMA_64x256x32_F16E5M2E4M3_SS_TN{}; } #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 240 == 0) { - return SM90_64x240x32_F16E5M2E4M3_SS_TN{}; + return SM90::GMMA::MMA_64x240x32_F16E5M2E4M3_SS_TN{}; } #endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 224 == 0) { - return SM90_64x224x32_F16E5M2E4M3_SS_TN{}; + return SM90::GMMA::MMA_64x224x32_F16E5M2E4M3_SS_TN{}; } #endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 208 == 0) { - return SM90_64x208x32_F16E5M2E4M3_SS_TN{}; + return SM90::GMMA::MMA_64x208x32_F16E5M2E4M3_SS_TN{}; } #endif else if constexpr (Tile_N % 192 == 0) { - return SM90_64x192x32_F16E5M2E4M3_SS_TN{}; + return SM90::GMMA::MMA_64x192x32_F16E5M2E4M3_SS_TN{}; } #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 176 == 0) { - return SM90_64x176x32_F16E5M2E4M3_SS_TN{}; + return SM90::GMMA::MMA_64x176x32_F16E5M2E4M3_SS_TN{}; } #endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 160 == 0) { - return SM90_64x160x32_F16E5M2E4M3_SS_TN{}; + return SM90::GMMA::MMA_64x160x32_F16E5M2E4M3_SS_TN{}; } #endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 144 == 0) { - return SM90_64x144x32_F16E5M2E4M3_SS_TN{}; + return SM90::GMMA::MMA_64x144x32_F16E5M2E4M3_SS_TN{}; } #endif else if constexpr (Tile_N % 128 == 0) { - return SM90_64x128x32_F16E5M2E4M3_SS_TN{}; + return SM90::GMMA::MMA_64x128x32_F16E5M2E4M3_SS_TN{}; } #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 112 == 0) { - return SM90_64x112x32_F16E5M2E4M3_SS_TN{}; + return SM90::GMMA::MMA_64x112x32_F16E5M2E4M3_SS_TN{}; } #endif else if constexpr (Tile_N % 96 
== 0) { - return SM90_64x96x32_F16E5M2E4M3_SS_TN{}; + return SM90::GMMA::MMA_64x96x32_F16E5M2E4M3_SS_TN{}; } #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 80 == 0) { - return SM90_64x80x32_F16E5M2E4M3_SS_TN{}; + return SM90::GMMA::MMA_64x80x32_F16E5M2E4M3_SS_TN{}; } #endif else if constexpr (Tile_N % 64 == 0) { - return SM90_64x64x32_F16E5M2E4M3_SS_TN{}; + return SM90::GMMA::MMA_64x64x32_F16E5M2E4M3_SS_TN{}; } #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 48 == 0) { - return SM90_64x48x32_F16E5M2E4M3_SS_TN{}; + return SM90::GMMA::MMA_64x48x32_F16E5M2E4M3_SS_TN{}; } #endif else if constexpr (Tile_N % 32 == 0) { - return SM90_64x32x32_F16E5M2E4M3_SS_TN{}; + return SM90::GMMA::MMA_64x32x32_F16E5M2E4M3_SS_TN{}; } else if constexpr (Tile_N % 16 == 0) { - return SM90_64x16x32_F16E5M2E4M3_SS_TN{}; + return SM90::GMMA::MMA_64x16x32_F16E5M2E4M3_SS_TN{}; } else if constexpr (Tile_N % 8 == 0) { - return SM90_64x8x32_F16E5M2E4M3_SS_TN{}; + return SM90::GMMA::MMA_64x8x32_F16E5M2E4M3_SS_TN{}; } else { static_assert(Tile_N % 8 == 0, "Tile_N must be a multiple of 8."); @@ -690,73 +698,73 @@ ss_op_selector() static_assert(size<2>(TileShape_MNK{}) % 32 == 0, "Tile_K must be a multiple of 32."); if constexpr (Tile_N % 256 == 0) { - return SM90_64x256x32_F16E5M2E5M2_SS_TN{}; + return SM90::GMMA::MMA_64x256x32_F16E5M2E5M2_SS_TN{}; } #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 240 == 0) { - return SM90_64x240x32_F16E5M2E5M2_SS_TN{}; + return SM90::GMMA::MMA_64x240x32_F16E5M2E5M2_SS_TN{}; } #endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 224 == 0) { - return SM90_64x224x32_F16E5M2E5M2_SS_TN{}; + return SM90::GMMA::MMA_64x224x32_F16E5M2E5M2_SS_TN{}; } #endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 208 == 0) { - return SM90_64x208x32_F16E5M2E5M2_SS_TN{}; + return SM90::GMMA::MMA_64x208x32_F16E5M2E5M2_SS_TN{}; } #endif else if constexpr (Tile_N % 192 == 0) { - return SM90_64x192x32_F16E5M2E5M2_SS_TN{}; + return SM90::GMMA::MMA_64x192x32_F16E5M2E5M2_SS_TN{}; } #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 176 == 0) { - return SM90_64x176x32_F16E5M2E5M2_SS_TN{}; + return SM90::GMMA::MMA_64x176x32_F16E5M2E5M2_SS_TN{}; } #endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 160 == 0) { - return SM90_64x160x32_F16E5M2E5M2_SS_TN{}; + return SM90::GMMA::MMA_64x160x32_F16E5M2E5M2_SS_TN{}; } #endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 144 == 0) { - return SM90_64x144x32_F16E5M2E5M2_SS_TN{}; + return SM90::GMMA::MMA_64x144x32_F16E5M2E5M2_SS_TN{}; } #endif else if constexpr (Tile_N % 128 == 0) { - return SM90_64x128x32_F16E5M2E5M2_SS_TN{}; + return SM90::GMMA::MMA_64x128x32_F16E5M2E5M2_SS_TN{}; } #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 112 == 0) { - return SM90_64x112x32_F16E5M2E5M2_SS_TN{}; + return SM90::GMMA::MMA_64x112x32_F16E5M2E5M2_SS_TN{}; } #endif else if constexpr (Tile_N % 96 == 0) { - return SM90_64x96x32_F16E5M2E5M2_SS_TN{}; + return SM90::GMMA::MMA_64x96x32_F16E5M2E5M2_SS_TN{}; } #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 80 == 0) { - return SM90_64x80x32_F16E5M2E5M2_SS_TN{}; + return SM90::GMMA::MMA_64x80x32_F16E5M2E5M2_SS_TN{}; } #endif else if constexpr (Tile_N % 64 == 0) { - return SM90_64x64x32_F16E5M2E5M2_SS_TN{}; + return 
SM90::GMMA::MMA_64x64x32_F16E5M2E5M2_SS_TN{}; } #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 48 == 0) { - return SM90_64x48x32_F16E5M2E5M2_SS_TN{}; + return SM90::GMMA::MMA_64x48x32_F16E5M2E5M2_SS_TN{}; } #endif else if constexpr (Tile_N % 32 == 0) { - return SM90_64x32x32_F16E5M2E5M2_SS_TN{}; + return SM90::GMMA::MMA_64x32x32_F16E5M2E5M2_SS_TN{}; } else if constexpr (Tile_N % 16 == 0) { - return SM90_64x16x32_F16E5M2E5M2_SS_TN{}; + return SM90::GMMA::MMA_64x16x32_F16E5M2E5M2_SS_TN{}; } else if constexpr (Tile_N % 8 == 0) { - return SM90_64x8x32_F16E5M2E5M2_SS_TN{}; + return SM90::GMMA::MMA_64x8x32_F16E5M2E5M2_SS_TN{}; } else { static_assert(Tile_N % 8 == 0, "Tile_N must be a multiple of 8."); @@ -776,73 +784,73 @@ ss_op_selector() static_assert(size<2>(TileShape_MNK{}) % 16 == 0, "Tile_K must be a multiple of 16."); if constexpr (Tile_N % 256 == 0) { - return SM90_64x256x16_F32F16F16_SS{}; + return SM90::GMMA::MMA_64x256x16_F32F16F16_SS{}; } #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 240 == 0) { - return SM90_64x240x16_F32F16F16_SS{}; + return SM90::GMMA::MMA_64x240x16_F32F16F16_SS{}; } #endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 224 == 0) { - return SM90_64x224x16_F32F16F16_SS{}; + return SM90::GMMA::MMA_64x224x16_F32F16F16_SS{}; } #endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 208 == 0) { - return SM90_64x208x16_F32F16F16_SS{}; + return SM90::GMMA::MMA_64x208x16_F32F16F16_SS{}; } #endif else if constexpr (Tile_N % 192 == 0) { - return SM90_64x192x16_F32F16F16_SS{}; + return SM90::GMMA::MMA_64x192x16_F32F16F16_SS{}; } #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 176 == 0) { - return SM90_64x176x16_F32F16F16_SS{}; + return SM90::GMMA::MMA_64x176x16_F32F16F16_SS{}; } #endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 160 == 0) { - return SM90_64x160x16_F32F16F16_SS{}; + return SM90::GMMA::MMA_64x160x16_F32F16F16_SS{}; } #endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 144 == 0) { - return SM90_64x144x16_F32F16F16_SS{}; + return SM90::GMMA::MMA_64x144x16_F32F16F16_SS{}; } #endif else if constexpr (Tile_N % 128 == 0) { - return SM90_64x128x16_F32F16F16_SS{}; + return SM90::GMMA::MMA_64x128x16_F32F16F16_SS{}; } #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 112 == 0) { - return SM90_64x112x16_F32F16F16_SS{}; + return SM90::GMMA::MMA_64x112x16_F32F16F16_SS{}; } #endif else if constexpr (Tile_N % 96 == 0) { - return SM90_64x96x16_F32F16F16_SS{}; + return SM90::GMMA::MMA_64x96x16_F32F16F16_SS{}; } #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 80 == 0) { - return SM90_64x80x16_F32F16F16_SS{}; + return SM90::GMMA::MMA_64x80x16_F32F16F16_SS{}; } #endif else if constexpr (Tile_N % 64 == 0) { - return SM90_64x64x16_F32F16F16_SS{}; + return SM90::GMMA::MMA_64x64x16_F32F16F16_SS{}; } #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 48 == 0) { - return SM90_64x48x16_F32F16F16_SS{}; + return SM90::GMMA::MMA_64x48x16_F32F16F16_SS{}; } #endif else if constexpr (Tile_N % 32 == 0) { - return SM90_64x32x16_F32F16F16_SS{}; + return SM90::GMMA::MMA_64x32x16_F32F16F16_SS{}; } else if constexpr (Tile_N % 16 == 0) { - return SM90_64x16x16_F32F16F16_SS{}; + return SM90::GMMA::MMA_64x16x16_F32F16F16_SS{}; } else if constexpr (Tile_N % 8 == 0) { - return 
SM90_64x8x16_F32F16F16_SS{}; + return SM90::GMMA::MMA_64x8x16_F32F16F16_SS{}; } else { static_assert(Tile_N % 8 == 0, "Tile_N must be a multiple of 8."); @@ -854,73 +862,73 @@ ss_op_selector() static_assert(size<2>(TileShape_MNK{}) % 16 == 0, "Tile_K must be a multiple of 16."); if constexpr (Tile_N % 256 == 0) { - return SM90_64x256x16_F32BF16BF16_SS{}; + return SM90::GMMA::MMA_64x256x16_F32BF16BF16_SS{}; } #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 240 == 0) { - return SM90_64x240x16_F32BF16BF16_SS{}; + return SM90::GMMA::MMA_64x240x16_F32BF16BF16_SS{}; } #endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 224 == 0) { - return SM90_64x224x16_F32BF16BF16_SS{}; + return SM90::GMMA::MMA_64x224x16_F32BF16BF16_SS{}; } #endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 208 == 0) { - return SM90_64x208x16_F32BF16BF16_SS{}; + return SM90::GMMA::MMA_64x208x16_F32BF16BF16_SS{}; } #endif else if constexpr (Tile_N % 192 == 0) { - return SM90_64x192x16_F32BF16BF16_SS{}; + return SM90::GMMA::MMA_64x192x16_F32BF16BF16_SS{}; } #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 176 == 0) { - return SM90_64x176x16_F32BF16BF16_SS{}; + return SM90::GMMA::MMA_64x176x16_F32BF16BF16_SS{}; } #endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 160 == 0) { - return SM90_64x160x16_F32BF16BF16_SS{}; + return SM90::GMMA::MMA_64x160x16_F32BF16BF16_SS{}; } #endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 144 == 0) { - return SM90_64x144x16_F32BF16BF16_SS{}; + return SM90::GMMA::MMA_64x144x16_F32BF16BF16_SS{}; } #endif else if constexpr (Tile_N % 128 == 0) { - return SM90_64x128x16_F32BF16BF16_SS{}; + return SM90::GMMA::MMA_64x128x16_F32BF16BF16_SS{}; } #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 112 == 0) { - return SM90_64x112x16_F32BF16BF16_SS{}; + return SM90::GMMA::MMA_64x112x16_F32BF16BF16_SS{}; } #endif else if constexpr (Tile_N % 96 == 0) { - return SM90_64x96x16_F32BF16BF16_SS{}; + return SM90::GMMA::MMA_64x96x16_F32BF16BF16_SS{}; } #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 80 == 0) { - return SM90_64x80x16_F32BF16BF16_SS{}; + return SM90::GMMA::MMA_64x80x16_F32BF16BF16_SS{}; } #endif else if constexpr (Tile_N % 64 == 0) { - return SM90_64x64x16_F32BF16BF16_SS{}; + return SM90::GMMA::MMA_64x64x16_F32BF16BF16_SS{}; } #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 48 == 0) { - return SM90_64x48x16_F32BF16BF16_SS{}; + return SM90::GMMA::MMA_64x48x16_F32BF16BF16_SS{}; } #endif else if constexpr (Tile_N % 32 == 0) { - return SM90_64x32x16_F32BF16BF16_SS{}; + return SM90::GMMA::MMA_64x32x16_F32BF16BF16_SS{}; } else if constexpr (Tile_N % 16 == 0) { - return SM90_64x16x16_F32BF16BF16_SS{}; + return SM90::GMMA::MMA_64x16x16_F32BF16BF16_SS{}; } else if constexpr (Tile_N % 8 == 0) { - return SM90_64x8x16_F32BF16BF16_SS{}; + return SM90::GMMA::MMA_64x8x16_F32BF16BF16_SS{}; } else { static_assert(Tile_N % 8 == 0, "Tile_N must be a multiple of 8."); @@ -934,73 +942,73 @@ ss_op_selector() static_assert(size<2>(TileShape_MNK{}) % 8 == 0, "Tile_K must be a multiple of 8."); if constexpr (Tile_N % 256 == 0) { - return SM90_64x256x8_F32TF32TF32_SS_TN{}; + return SM90::GMMA::MMA_64x256x8_F32TF32TF32_SS_TN{}; } #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 240 == 0) { - return 
SM90_64x240x8_F32TF32TF32_SS_TN{}; + return SM90::GMMA::MMA_64x240x8_F32TF32TF32_SS_TN{}; } #endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 224 == 0) { - return SM90_64x224x8_F32TF32TF32_SS_TN{}; + return SM90::GMMA::MMA_64x224x8_F32TF32TF32_SS_TN{}; } #endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 208 == 0) { - return SM90_64x208x8_F32TF32TF32_SS_TN{}; + return SM90::GMMA::MMA_64x208x8_F32TF32TF32_SS_TN{}; } #endif else if constexpr (Tile_N % 192 == 0) { - return SM90_64x192x8_F32TF32TF32_SS_TN{}; + return SM90::GMMA::MMA_64x192x8_F32TF32TF32_SS_TN{}; } #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 176 == 0) { - return SM90_64x176x8_F32TF32TF32_SS_TN{}; + return SM90::GMMA::MMA_64x176x8_F32TF32TF32_SS_TN{}; } #endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 160 == 0) { - return SM90_64x160x8_F32TF32TF32_SS_TN{}; + return SM90::GMMA::MMA_64x160x8_F32TF32TF32_SS_TN{}; } #endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 144 == 0) { - return SM90_64x144x8_F32TF32TF32_SS_TN{}; + return SM90::GMMA::MMA_64x144x8_F32TF32TF32_SS_TN{}; } #endif else if constexpr (Tile_N % 128 == 0) { - return SM90_64x128x8_F32TF32TF32_SS_TN{}; + return SM90::GMMA::MMA_64x128x8_F32TF32TF32_SS_TN{}; } #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 112 == 0) { - return SM90_64x112x8_F32TF32TF32_SS_TN{}; + return SM90::GMMA::MMA_64x112x8_F32TF32TF32_SS_TN{}; } #endif else if constexpr (Tile_N % 96 == 0) { - return SM90_64x96x8_F32TF32TF32_SS_TN{}; + return SM90::GMMA::MMA_64x96x8_F32TF32TF32_SS_TN{}; } #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 80 == 0) { - return SM90_64x80x8_F32TF32TF32_SS_TN{}; + return SM90::GMMA::MMA_64x80x8_F32TF32TF32_SS_TN{}; } #endif else if constexpr (Tile_N % 64 == 0) { - return SM90_64x64x8_F32TF32TF32_SS_TN{}; + return SM90::GMMA::MMA_64x64x8_F32TF32TF32_SS_TN{}; } #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 48 == 0) { - return SM90_64x48x8_F32TF32TF32_SS_TN{}; + return SM90::GMMA::MMA_64x48x8_F32TF32TF32_SS_TN{}; } #endif else if constexpr (Tile_N % 32 == 0) { - return SM90_64x32x8_F32TF32TF32_SS_TN{}; + return SM90::GMMA::MMA_64x32x8_F32TF32TF32_SS_TN{}; } else if constexpr (Tile_N % 16 == 0) { - return SM90_64x16x8_F32TF32TF32_SS_TN{}; + return SM90::GMMA::MMA_64x16x8_F32TF32TF32_SS_TN{}; } else if constexpr (Tile_N % 8 == 0) { - return SM90_64x8x8_F32TF32TF32_SS_TN{}; + return SM90::GMMA::MMA_64x8x8_F32TF32TF32_SS_TN{}; } else { static_assert(Tile_N % 8 == 0, "Tile_N must be a multiple of 8."); @@ -1014,73 +1022,73 @@ ss_op_selector() static_assert(size<2>(TileShape_MNK{}) % 32 == 0, "Tile_K must be a multiple of 32."); if constexpr (Tile_N % 256 == 0) { - return SM90_64x256x32_F32E4M3E4M3_SS_TN{}; + return SM90::GMMA::MMA_64x256x32_F32E4M3E4M3_SS_TN{}; } #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 240 == 0) { - return SM90_64x240x32_F32E4M3E4M3_SS_TN{}; + return SM90::GMMA::MMA_64x240x32_F32E4M3E4M3_SS_TN{}; } #endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 224 == 0) { - return SM90_64x224x32_F32E4M3E4M3_SS_TN{}; + return SM90::GMMA::MMA_64x224x32_F32E4M3E4M3_SS_TN{}; } #endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 208 == 0) { - return SM90_64x208x32_F32E4M3E4M3_SS_TN{}; + return 
SM90::GMMA::MMA_64x208x32_F32E4M3E4M3_SS_TN{}; } #endif else if constexpr (Tile_N % 192 == 0) { - return SM90_64x192x32_F32E4M3E4M3_SS_TN{}; + return SM90::GMMA::MMA_64x192x32_F32E4M3E4M3_SS_TN{}; } #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 176 == 0) { - return SM90_64x176x32_F32E4M3E4M3_SS_TN{}; + return SM90::GMMA::MMA_64x176x32_F32E4M3E4M3_SS_TN{}; } #endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 160 == 0) { - return SM90_64x160x32_F32E4M3E4M3_SS_TN{}; + return SM90::GMMA::MMA_64x160x32_F32E4M3E4M3_SS_TN{}; } #endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 144 == 0) { - return SM90_64x144x32_F32E4M3E4M3_SS_TN{}; + return SM90::GMMA::MMA_64x144x32_F32E4M3E4M3_SS_TN{}; } #endif else if constexpr (Tile_N % 128 == 0) { - return SM90_64x128x32_F32E4M3E4M3_SS_TN{}; + return SM90::GMMA::MMA_64x128x32_F32E4M3E4M3_SS_TN{}; } #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 112 == 0) { - return SM90_64x112x32_F32E4M3E4M3_SS_TN{}; + return SM90::GMMA::MMA_64x112x32_F32E4M3E4M3_SS_TN{}; } #endif else if constexpr (Tile_N % 96 == 0) { - return SM90_64x96x32_F32E4M3E4M3_SS_TN{}; + return SM90::GMMA::MMA_64x96x32_F32E4M3E4M3_SS_TN{}; } #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 80 == 0) { - return SM90_64x80x32_F32E4M3E4M3_SS_TN{}; + return SM90::GMMA::MMA_64x80x32_F32E4M3E4M3_SS_TN{}; } #endif else if constexpr (Tile_N % 64 == 0) { - return SM90_64x64x32_F32E4M3E4M3_SS_TN{}; + return SM90::GMMA::MMA_64x64x32_F32E4M3E4M3_SS_TN{}; } #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 48 == 0) { - return SM90_64x48x32_F32E4M3E4M3_SS_TN{}; + return SM90::GMMA::MMA_64x48x32_F32E4M3E4M3_SS_TN{}; } #endif else if constexpr (Tile_N % 32 == 0) { - return SM90_64x32x32_F32E4M3E4M3_SS_TN{}; + return SM90::GMMA::MMA_64x32x32_F32E4M3E4M3_SS_TN{}; } else if constexpr (Tile_N % 16 == 0) { - return SM90_64x16x32_F32E4M3E4M3_SS_TN{}; + return SM90::GMMA::MMA_64x16x32_F32E4M3E4M3_SS_TN{}; } else if constexpr (Tile_N % 8 == 0) { - return SM90_64x8x32_F32E4M3E4M3_SS_TN{}; + return SM90::GMMA::MMA_64x8x32_F32E4M3E4M3_SS_TN{}; } else { static_assert(Tile_N % 8 == 0, "Tile_N must be a multiple of 8."); @@ -1094,73 +1102,73 @@ ss_op_selector() static_assert(size<2>(TileShape_MNK{}) % 32 == 0, "Tile_K must be a multiple of 32."); if constexpr (Tile_N % 256 == 0) { - return SM90_64x256x32_F32E4M3E5M2_SS_TN{}; + return SM90::GMMA::MMA_64x256x32_F32E4M3E5M2_SS_TN{}; } #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 240 == 0) { - return SM90_64x240x32_F32E4M3E5M2_SS_TN{}; + return SM90::GMMA::MMA_64x240x32_F32E4M3E5M2_SS_TN{}; } #endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 224 == 0) { - return SM90_64x224x32_F32E4M3E5M2_SS_TN{}; + return SM90::GMMA::MMA_64x224x32_F32E4M3E5M2_SS_TN{}; } #endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 208 == 0) { - return SM90_64x208x32_F32E4M3E5M2_SS_TN{}; + return SM90::GMMA::MMA_64x208x32_F32E4M3E5M2_SS_TN{}; } #endif else if constexpr (Tile_N % 192 == 0) { - return SM90_64x192x32_F32E4M3E5M2_SS_TN{}; + return SM90::GMMA::MMA_64x192x32_F32E4M3E5M2_SS_TN{}; } #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 176 == 0) { - return SM90_64x176x32_F32E4M3E5M2_SS_TN{}; + return SM90::GMMA::MMA_64x176x32_F32E4M3E5M2_SS_TN{}; } #endif #if 
defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 160 == 0) { - return SM90_64x160x32_F32E4M3E5M2_SS_TN{}; + return SM90::GMMA::MMA_64x160x32_F32E4M3E5M2_SS_TN{}; } #endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 144 == 0) { - return SM90_64x144x32_F32E4M3E5M2_SS_TN{}; + return SM90::GMMA::MMA_64x144x32_F32E4M3E5M2_SS_TN{}; } #endif else if constexpr (Tile_N % 128 == 0) { - return SM90_64x128x32_F32E4M3E5M2_SS_TN{}; + return SM90::GMMA::MMA_64x128x32_F32E4M3E5M2_SS_TN{}; } #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 112 == 0) { - return SM90_64x112x32_F32E4M3E5M2_SS_TN{}; + return SM90::GMMA::MMA_64x112x32_F32E4M3E5M2_SS_TN{}; } #endif else if constexpr (Tile_N % 96 == 0) { - return SM90_64x96x32_F32E4M3E5M2_SS_TN{}; + return SM90::GMMA::MMA_64x96x32_F32E4M3E5M2_SS_TN{}; } #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 80 == 0) { - return SM90_64x80x32_F32E4M3E5M2_SS_TN{}; + return SM90::GMMA::MMA_64x80x32_F32E4M3E5M2_SS_TN{}; } #endif else if constexpr (Tile_N % 64 == 0) { - return SM90_64x64x32_F32E4M3E5M2_SS_TN{}; + return SM90::GMMA::MMA_64x64x32_F32E4M3E5M2_SS_TN{}; } #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 48 == 0) { - return SM90_64x48x32_F32E4M3E5M2_SS_TN{}; + return SM90::GMMA::MMA_64x48x32_F32E4M3E5M2_SS_TN{}; } #endif else if constexpr (Tile_N % 32 == 0) { - return SM90_64x32x32_F32E4M3E5M2_SS_TN{}; + return SM90::GMMA::MMA_64x32x32_F32E4M3E5M2_SS_TN{}; } else if constexpr (Tile_N % 16 == 0) { - return SM90_64x16x32_F32E4M3E5M2_SS_TN{}; + return SM90::GMMA::MMA_64x16x32_F32E4M3E5M2_SS_TN{}; } else if constexpr (Tile_N % 8 == 0) { - return SM90_64x8x32_F32E4M3E5M2_SS_TN{}; + return SM90::GMMA::MMA_64x8x32_F32E4M3E5M2_SS_TN{}; } else { static_assert(Tile_N % 8 == 0, "Tile_N must be a multiple of 8."); @@ -1174,73 +1182,73 @@ ss_op_selector() static_assert(size<2>(TileShape_MNK{}) % 32 == 0, "Tile_K must be a multiple of 32."); if constexpr (Tile_N % 256 == 0) { - return SM90_64x256x32_F32E5M2E4M3_SS_TN{}; + return SM90::GMMA::MMA_64x256x32_F32E5M2E4M3_SS_TN{}; } #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 240 == 0) { - return SM90_64x240x32_F32E5M2E4M3_SS_TN{}; + return SM90::GMMA::MMA_64x240x32_F32E5M2E4M3_SS_TN{}; } #endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 224 == 0) { - return SM90_64x224x32_F32E5M2E4M3_SS_TN{}; + return SM90::GMMA::MMA_64x224x32_F32E5M2E4M3_SS_TN{}; } #endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 208 == 0) { - return SM90_64x208x32_F32E5M2E4M3_SS_TN{}; + return SM90::GMMA::MMA_64x208x32_F32E5M2E4M3_SS_TN{}; } #endif else if constexpr (Tile_N % 192 == 0) { - return SM90_64x192x32_F32E5M2E4M3_SS_TN{}; + return SM90::GMMA::MMA_64x192x32_F32E5M2E4M3_SS_TN{}; } #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 176 == 0) { - return SM90_64x176x32_F32E5M2E4M3_SS_TN{}; + return SM90::GMMA::MMA_64x176x32_F32E5M2E4M3_SS_TN{}; } #endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 160 == 0) { - return SM90_64x160x32_F32E5M2E4M3_SS_TN{}; + return SM90::GMMA::MMA_64x160x32_F32E5M2E4M3_SS_TN{}; } #endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 144 == 0) { - return SM90_64x144x32_F32E5M2E4M3_SS_TN{}; + return SM90::GMMA::MMA_64x144x32_F32E5M2E4M3_SS_TN{}; } #endif else if constexpr (Tile_N % 
128 == 0) { - return SM90_64x128x32_F32E5M2E4M3_SS_TN{}; + return SM90::GMMA::MMA_64x128x32_F32E5M2E4M3_SS_TN{}; } #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 112 == 0) { - return SM90_64x112x32_F32E5M2E4M3_SS_TN{}; + return SM90::GMMA::MMA_64x112x32_F32E5M2E4M3_SS_TN{}; } #endif else if constexpr (Tile_N % 96 == 0) { - return SM90_64x96x32_F32E5M2E4M3_SS_TN{}; + return SM90::GMMA::MMA_64x96x32_F32E5M2E4M3_SS_TN{}; } #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 80 == 0) { - return SM90_64x80x32_F32E5M2E4M3_SS_TN{}; + return SM90::GMMA::MMA_64x80x32_F32E5M2E4M3_SS_TN{}; } #endif else if constexpr (Tile_N % 64 == 0) { - return SM90_64x64x32_F32E5M2E4M3_SS_TN{}; + return SM90::GMMA::MMA_64x64x32_F32E5M2E4M3_SS_TN{}; } #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 48 == 0) { - return SM90_64x48x32_F32E5M2E4M3_SS_TN{}; + return SM90::GMMA::MMA_64x48x32_F32E5M2E4M3_SS_TN{}; } #endif else if constexpr (Tile_N % 32 == 0) { - return SM90_64x32x32_F32E5M2E4M3_SS_TN{}; + return SM90::GMMA::MMA_64x32x32_F32E5M2E4M3_SS_TN{}; } else if constexpr (Tile_N % 16 == 0) { - return SM90_64x16x32_F32E5M2E4M3_SS_TN{}; + return SM90::GMMA::MMA_64x16x32_F32E5M2E4M3_SS_TN{}; } else if constexpr (Tile_N % 8 == 0) { - return SM90_64x8x32_F32E5M2E4M3_SS_TN{}; + return SM90::GMMA::MMA_64x8x32_F32E5M2E4M3_SS_TN{}; } else { static_assert(Tile_N % 8 == 0, "Tile_N must be a multiple of 8."); @@ -1254,73 +1262,73 @@ ss_op_selector() static_assert(size<2>(TileShape_MNK{}) % 32 == 0, "Tile_K must be a multiple of 32."); if constexpr (Tile_N % 256 == 0) { - return SM90_64x256x32_F32E5M2E5M2_SS_TN{}; + return SM90::GMMA::MMA_64x256x32_F32E5M2E5M2_SS_TN{}; } #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 240 == 0) { - return SM90_64x240x32_F32E5M2E5M2_SS_TN{}; + return SM90::GMMA::MMA_64x240x32_F32E5M2E5M2_SS_TN{}; } #endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 224 == 0) { - return SM90_64x224x32_F32E5M2E5M2_SS_TN{}; + return SM90::GMMA::MMA_64x224x32_F32E5M2E5M2_SS_TN{}; } #endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 208 == 0) { - return SM90_64x208x32_F32E5M2E5M2_SS_TN{}; + return SM90::GMMA::MMA_64x208x32_F32E5M2E5M2_SS_TN{}; } #endif else if constexpr (Tile_N % 192 == 0) { - return SM90_64x192x32_F32E5M2E5M2_SS_TN{}; + return SM90::GMMA::MMA_64x192x32_F32E5M2E5M2_SS_TN{}; } #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 176 == 0) { - return SM90_64x176x32_F32E5M2E5M2_SS_TN{}; + return SM90::GMMA::MMA_64x176x32_F32E5M2E5M2_SS_TN{}; } #endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 160 == 0) { - return SM90_64x160x32_F32E5M2E5M2_SS_TN{}; + return SM90::GMMA::MMA_64x160x32_F32E5M2E5M2_SS_TN{}; } #endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 144 == 0) { - return SM90_64x144x32_F32E5M2E5M2_SS_TN{}; + return SM90::GMMA::MMA_64x144x32_F32E5M2E5M2_SS_TN{}; } #endif else if constexpr (Tile_N % 128 == 0) { - return SM90_64x128x32_F32E5M2E5M2_SS_TN{}; + return SM90::GMMA::MMA_64x128x32_F32E5M2E5M2_SS_TN{}; } #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 112 == 0) { - return SM90_64x112x32_F32E5M2E5M2_SS_TN{}; + return SM90::GMMA::MMA_64x112x32_F32E5M2E5M2_SS_TN{}; } #endif else if constexpr (Tile_N % 96 == 0) { - return SM90_64x96x32_F32E5M2E5M2_SS_TN{}; + return 
SM90::GMMA::MMA_64x96x32_F32E5M2E5M2_SS_TN{}; } #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 80 == 0) { - return SM90_64x80x32_F32E5M2E5M2_SS_TN{}; + return SM90::GMMA::MMA_64x80x32_F32E5M2E5M2_SS_TN{}; } #endif else if constexpr (Tile_N % 64 == 0) { - return SM90_64x64x32_F32E5M2E5M2_SS_TN{}; + return SM90::GMMA::MMA_64x64x32_F32E5M2E5M2_SS_TN{}; } #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 48 == 0) { - return SM90_64x48x32_F32E5M2E5M2_SS_TN{}; + return SM90::GMMA::MMA_64x48x32_F32E5M2E5M2_SS_TN{}; } #endif else if constexpr (Tile_N % 32 == 0) { - return SM90_64x32x32_F32E5M2E5M2_SS_TN{}; + return SM90::GMMA::MMA_64x32x32_F32E5M2E5M2_SS_TN{}; } else if constexpr (Tile_N % 16 == 0) { - return SM90_64x16x32_F32E5M2E5M2_SS_TN{}; + return SM90::GMMA::MMA_64x16x32_F32E5M2E5M2_SS_TN{}; } else if constexpr (Tile_N % 8 == 0) { - return SM90_64x8x32_F32E5M2E5M2_SS_TN{}; + return SM90::GMMA::MMA_64x8x32_F32E5M2E5M2_SS_TN{}; } else { static_assert(Tile_N % 8 == 0, "Tile_N must be a multiple of 8."); @@ -1342,73 +1350,73 @@ ss_op_selector() static_assert(size<2>(TileShape_MNK{}) % 32 == 0, "Tile_K must be a multiple of 32."); if constexpr (Tile_N % 256 == 0) { - return SM90_64x256x32_S32S8S8_SS_TN{}; + return SM90::GMMA::MMA_64x256x32_S32S8S8_SS_TN{}; } #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 240 == 0) { - return SM90_64x240x32_S32S8S8_SS_TN{}; + return SM90::GMMA::MMA_64x240x32_S32S8S8_SS_TN{}; } #endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 224 == 0) { - return SM90_64x224x32_S32S8S8_SS_TN{}; + return SM90::GMMA::MMA_64x224x32_S32S8S8_SS_TN{}; } #endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 208 == 0) { - return SM90_64x208x32_S32S8S8_SS_TN{}; + return SM90::GMMA::MMA_64x208x32_S32S8S8_SS_TN{}; } #endif else if constexpr (Tile_N % 192 == 0) { - return SM90_64x192x32_S32S8S8_SS_TN{}; + return SM90::GMMA::MMA_64x192x32_S32S8S8_SS_TN{}; } #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 176 == 0) { - return SM90_64x176x32_S32S8S8_SS_TN{}; + return SM90::GMMA::MMA_64x176x32_S32S8S8_SS_TN{}; } #endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 160 == 0) { - return SM90_64x160x32_S32S8S8_SS_TN{}; + return SM90::GMMA::MMA_64x160x32_S32S8S8_SS_TN{}; } #endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 144 == 0) { - return SM90_64x144x32_S32S8S8_SS_TN{}; + return SM90::GMMA::MMA_64x144x32_S32S8S8_SS_TN{}; } #endif else if constexpr (Tile_N % 128 == 0) { - return SM90_64x128x32_S32S8S8_SS_TN{}; + return SM90::GMMA::MMA_64x128x32_S32S8S8_SS_TN{}; } #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 112 == 0) { - return SM90_64x112x32_S32S8S8_SS_TN{}; + return SM90::GMMA::MMA_64x112x32_S32S8S8_SS_TN{}; } #endif else if constexpr (Tile_N % 96 == 0) { - return SM90_64x96x32_S32S8S8_SS_TN{}; + return SM90::GMMA::MMA_64x96x32_S32S8S8_SS_TN{}; } #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 80 == 0) { - return SM90_64x80x32_S32S8S8_SS_TN{}; + return SM90::GMMA::MMA_64x80x32_S32S8S8_SS_TN{}; } #endif else if constexpr (Tile_N % 64 == 0) { - return SM90_64x64x32_S32S8S8_SS_TN{}; + return SM90::GMMA::MMA_64x64x32_S32S8S8_SS_TN{}; } #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 48 == 0) { - return SM90_64x48x32_S32S8S8_SS_TN{}; + return 
SM90::GMMA::MMA_64x48x32_S32S8S8_SS_TN{}; } #endif else if constexpr (Tile_N % 32 == 0) { - return SM90_64x32x32_S32S8S8_SS_TN{}; + return SM90::GMMA::MMA_64x32x32_S32S8S8_SS_TN{}; } else if constexpr (Tile_N % 16 == 0) { - return SM90_64x16x32_S32S8S8_SS_TN{}; + return SM90::GMMA::MMA_64x16x32_S32S8S8_SS_TN{}; } else if constexpr (Tile_N % 8 == 0) { - return SM90_64x8x32_S32S8S8_SS_TN{}; + return SM90::GMMA::MMA_64x8x32_S32S8S8_SS_TN{}; } else { static_assert(Tile_N % 8 == 0, "Tile_N must be a multiple of 8."); @@ -1422,73 +1430,73 @@ ss_op_selector() static_assert(size<2>(TileShape_MNK{}) % 32 == 0, "Tile_K must be a multiple of 32."); if constexpr (Tile_N % 256 == 0) { - return SM90_64x256x32_S32S8U8_SS_TN{}; + return SM90::GMMA::MMA_64x256x32_S32S8U8_SS_TN{}; } #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 240 == 0) { - return SM90_64x240x32_S32S8U8_SS_TN{}; + return SM90::GMMA::MMA_64x240x32_S32S8U8_SS_TN{}; } #endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 224 == 0) { - return SM90_64x224x32_S32S8U8_SS_TN{}; + return SM90::GMMA::MMA_64x224x32_S32S8U8_SS_TN{}; } #endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 208 == 0) { - return SM90_64x208x32_S32S8U8_SS_TN{}; + return SM90::GMMA::MMA_64x208x32_S32S8U8_SS_TN{}; } #endif else if constexpr (Tile_N % 192 == 0) { - return SM90_64x192x32_S32S8U8_SS_TN{}; + return SM90::GMMA::MMA_64x192x32_S32S8U8_SS_TN{}; } #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 176 == 0) { - return SM90_64x176x32_S32S8U8_SS_TN{}; + return SM90::GMMA::MMA_64x176x32_S32S8U8_SS_TN{}; } #endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 160 == 0) { - return SM90_64x160x32_S32S8U8_SS_TN{}; + return SM90::GMMA::MMA_64x160x32_S32S8U8_SS_TN{}; } #endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 144 == 0) { - return SM90_64x144x32_S32S8U8_SS_TN{}; + return SM90::GMMA::MMA_64x144x32_S32S8U8_SS_TN{}; } #endif else if constexpr (Tile_N % 128 == 0) { - return SM90_64x128x32_S32S8U8_SS_TN{}; + return SM90::GMMA::MMA_64x128x32_S32S8U8_SS_TN{}; } #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 112 == 0) { - return SM90_64x112x32_S32S8U8_SS_TN{}; + return SM90::GMMA::MMA_64x112x32_S32S8U8_SS_TN{}; } #endif else if constexpr (Tile_N % 96 == 0) { - return SM90_64x96x32_S32S8U8_SS_TN{}; + return SM90::GMMA::MMA_64x96x32_S32S8U8_SS_TN{}; } #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 80 == 0) { - return SM90_64x80x32_S32S8U8_SS_TN{}; + return SM90::GMMA::MMA_64x80x32_S32S8U8_SS_TN{}; } #endif else if constexpr (Tile_N % 64 == 0) { - return SM90_64x64x32_S32S8U8_SS_TN{}; + return SM90::GMMA::MMA_64x64x32_S32S8U8_SS_TN{}; } #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 48 == 0) { - return SM90_64x48x32_S32S8U8_SS_TN{}; + return SM90::GMMA::MMA_64x48x32_S32S8U8_SS_TN{}; } #endif else if constexpr (Tile_N % 32 == 0) { - return SM90_64x32x32_S32S8U8_SS_TN{}; + return SM90::GMMA::MMA_64x32x32_S32S8U8_SS_TN{}; } else if constexpr (Tile_N % 16 == 0) { - return SM90_64x16x32_S32S8U8_SS_TN{}; + return SM90::GMMA::MMA_64x16x32_S32S8U8_SS_TN{}; } else if constexpr (Tile_N % 8 == 0) { - return SM90_64x8x32_S32S8U8_SS_TN{}; + return SM90::GMMA::MMA_64x8x32_S32S8U8_SS_TN{}; } else { static_assert(Tile_N % 8 == 0, "Tile_N must be a multiple of 8."); @@ -1502,73 +1510,73 @@ ss_op_selector() 
static_assert(size<2>(TileShape_MNK{}) % 32 == 0, "Tile_K must be a multiple of 32."); if constexpr (Tile_N % 256 == 0) { - return SM90_64x256x32_S32U8S8_SS_TN{}; + return SM90::GMMA::MMA_64x256x32_S32U8S8_SS_TN{}; } #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 240 == 0) { - return SM90_64x240x32_S32U8S8_SS_TN{}; + return SM90::GMMA::MMA_64x240x32_S32U8S8_SS_TN{}; } #endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 224 == 0) { - return SM90_64x224x32_S32U8S8_SS_TN{}; + return SM90::GMMA::MMA_64x224x32_S32U8S8_SS_TN{}; } #endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 208 == 0) { - return SM90_64x208x32_S32U8S8_SS_TN{}; + return SM90::GMMA::MMA_64x208x32_S32U8S8_SS_TN{}; } #endif else if constexpr (Tile_N % 192 == 0) { - return SM90_64x192x32_S32U8S8_SS_TN{}; + return SM90::GMMA::MMA_64x192x32_S32U8S8_SS_TN{}; } #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 176 == 0) { - return SM90_64x176x32_S32U8S8_SS_TN{}; + return SM90::GMMA::MMA_64x176x32_S32U8S8_SS_TN{}; } #endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 160 == 0) { - return SM90_64x160x32_S32U8S8_SS_TN{}; + return SM90::GMMA::MMA_64x160x32_S32U8S8_SS_TN{}; } #endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 144 == 0) { - return SM90_64x144x32_S32U8S8_SS_TN{}; + return SM90::GMMA::MMA_64x144x32_S32U8S8_SS_TN{}; } #endif else if constexpr (Tile_N % 128 == 0) { - return SM90_64x128x32_S32U8S8_SS_TN{}; + return SM90::GMMA::MMA_64x128x32_S32U8S8_SS_TN{}; } #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 112 == 0) { - return SM90_64x112x32_S32U8S8_SS_TN{}; + return SM90::GMMA::MMA_64x112x32_S32U8S8_SS_TN{}; } #endif else if constexpr (Tile_N % 96 == 0) { - return SM90_64x96x32_S32U8S8_SS_TN{}; + return SM90::GMMA::MMA_64x96x32_S32U8S8_SS_TN{}; } #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 80 == 0) { - return SM90_64x80x32_S32U8S8_SS_TN{}; + return SM90::GMMA::MMA_64x80x32_S32U8S8_SS_TN{}; } #endif else if constexpr (Tile_N % 64 == 0) { - return SM90_64x64x32_S32U8S8_SS_TN{}; + return SM90::GMMA::MMA_64x64x32_S32U8S8_SS_TN{}; } #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 48 == 0) { - return SM90_64x48x32_S32U8S8_SS_TN{}; + return SM90::GMMA::MMA_64x48x32_S32U8S8_SS_TN{}; } #endif else if constexpr (Tile_N % 32 == 0) { - return SM90_64x32x32_S32U8S8_SS_TN{}; + return SM90::GMMA::MMA_64x32x32_S32U8S8_SS_TN{}; } else if constexpr (Tile_N % 16 == 0) { - return SM90_64x16x32_S32U8S8_SS_TN{}; + return SM90::GMMA::MMA_64x16x32_S32U8S8_SS_TN{}; } else if constexpr (Tile_N % 8 == 0) { - return SM90_64x8x32_S32U8S8_SS_TN{}; + return SM90::GMMA::MMA_64x8x32_S32U8S8_SS_TN{}; } else { static_assert(Tile_N % 8 == 0, "Tile_N must be a multiple of 8."); @@ -1582,73 +1590,73 @@ ss_op_selector() static_assert(size<2>(TileShape_MNK{}) % 32 == 0, "Tile_K must be a multiple of 32."); if constexpr (Tile_N % 256 == 0) { - return SM90_64x256x32_S32U8U8_SS_TN{}; + return SM90::GMMA::MMA_64x256x32_S32U8U8_SS_TN{}; } #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 240 == 0) { - return SM90_64x240x32_S32U8U8_SS_TN{}; + return SM90::GMMA::MMA_64x240x32_S32U8U8_SS_TN{}; } #endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 224 == 0) { - return SM90_64x224x32_S32U8U8_SS_TN{}; + return 
SM90::GMMA::MMA_64x224x32_S32U8U8_SS_TN{}; } #endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 208 == 0) { - return SM90_64x208x32_S32U8U8_SS_TN{}; + return SM90::GMMA::MMA_64x208x32_S32U8U8_SS_TN{}; } #endif else if constexpr (Tile_N % 192 == 0) { - return SM90_64x192x32_S32U8U8_SS_TN{}; + return SM90::GMMA::MMA_64x192x32_S32U8U8_SS_TN{}; } #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 176 == 0) { - return SM90_64x176x32_S32U8U8_SS_TN{}; + return SM90::GMMA::MMA_64x176x32_S32U8U8_SS_TN{}; } #endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 160 == 0) { - return SM90_64x160x32_S32U8U8_SS_TN{}; + return SM90::GMMA::MMA_64x160x32_S32U8U8_SS_TN{}; } #endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 144 == 0) { - return SM90_64x144x32_S32U8U8_SS_TN{}; + return SM90::GMMA::MMA_64x144x32_S32U8U8_SS_TN{}; } #endif else if constexpr (Tile_N % 128 == 0) { - return SM90_64x128x32_S32U8U8_SS_TN{}; + return SM90::GMMA::MMA_64x128x32_S32U8U8_SS_TN{}; } #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 112 == 0) { - return SM90_64x112x32_S32U8U8_SS_TN{}; + return SM90::GMMA::MMA_64x112x32_S32U8U8_SS_TN{}; } #endif else if constexpr (Tile_N % 96 == 0) { - return SM90_64x96x32_S32U8U8_SS_TN{}; + return SM90::GMMA::MMA_64x96x32_S32U8U8_SS_TN{}; } #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 80 == 0) { - return SM90_64x80x32_S32U8U8_SS_TN{}; + return SM90::GMMA::MMA_64x80x32_S32U8U8_SS_TN{}; } #endif else if constexpr (Tile_N % 64 == 0) { - return SM90_64x64x32_S32U8U8_SS_TN{}; + return SM90::GMMA::MMA_64x64x32_S32U8U8_SS_TN{}; } #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 48 == 0) { - return SM90_64x48x32_S32U8U8_SS_TN{}; + return SM90::GMMA::MMA_64x48x32_S32U8U8_SS_TN{}; } #endif else if constexpr (Tile_N % 32 == 0) { - return SM90_64x32x32_S32U8U8_SS_TN{}; + return SM90::GMMA::MMA_64x32x32_S32U8U8_SS_TN{}; } else if constexpr (Tile_N % 16 == 0) { - return SM90_64x16x32_S32U8U8_SS_TN{}; + return SM90::GMMA::MMA_64x16x32_S32U8U8_SS_TN{}; } else if constexpr (Tile_N % 8 == 0) { - return SM90_64x8x32_S32U8U8_SS_TN{}; + return SM90::GMMA::MMA_64x8x32_S32U8U8_SS_TN{}; } else { static_assert(Tile_N % 8 == 0, "Tile_N must be a multiple of 8."); @@ -1678,12 +1686,11 @@ template < > CUTE_HOST_DEVICE constexpr auto -rs_op_selector() +ss_op_selector_sparse() { static_assert(is_static::value, "TileShape_MNK must be static."); static_assert(rank(TileShape_MNK{}) == 3, "TileShape_MNK must be rank 3."); static_assert(size<0>(TileShape_MNK{}) % 64 == 0, "Tile_M must be a multiple of 64."); - static_assert(MajorA == GMMA::Major::K, "Register source A operand GMMAs must have K-major A layout."); auto Tile_N = size<1>(TileShape_MNK{}); // F16 accumulator @@ -1691,76 +1698,76 @@ rs_op_selector() // Input A: half_t ; Input B: half_t if constexpr (is_same_v && is_same_v) { - static_assert(size<2>(TileShape_MNK{}) % 16 == 0, "Tile_K must be a multiple of 16."); + static_assert(size<2>(TileShape_MNK{}) % 32 == 0, "Tile_K must be a multiple of 32."); if constexpr (Tile_N % 256 == 0) { - return SM90_64x256x16_F16F16F16_RS{}; + return SM90::GMMA::SPARSE::GMMA_64x256x32_F16F16F16_SS{}; } #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 240 == 0) { - return SM90_64x240x16_F16F16F16_RS{}; + return SM90::GMMA::SPARSE::GMMA_64x240x32_F16F16F16_SS{}; } #endif #if 
defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 224 == 0) { - return SM90_64x224x16_F16F16F16_RS{}; + return SM90::GMMA::SPARSE::GMMA_64x224x32_F16F16F16_SS{}; } #endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 208 == 0) { - return SM90_64x208x16_F16F16F16_RS{}; + return SM90::GMMA::SPARSE::GMMA_64x208x32_F16F16F16_SS{}; } #endif else if constexpr (Tile_N % 192 == 0) { - return SM90_64x192x16_F16F16F16_RS{}; + return SM90::GMMA::SPARSE::GMMA_64x192x32_F16F16F16_SS{}; } #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 176 == 0) { - return SM90_64x176x16_F16F16F16_RS{}; + return SM90::GMMA::SPARSE::GMMA_64x176x32_F16F16F16_SS{}; } #endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 160 == 0) { - return SM90_64x160x16_F16F16F16_RS{}; + return SM90::GMMA::SPARSE::GMMA_64x160x32_F16F16F16_SS{}; } #endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 144 == 0) { - return SM90_64x144x16_F16F16F16_RS{}; + return SM90::GMMA::SPARSE::GMMA_64x144x32_F16F16F16_SS{}; } #endif else if constexpr (Tile_N % 128 == 0) { - return SM90_64x128x16_F16F16F16_RS{}; + return SM90::GMMA::SPARSE::GMMA_64x128x32_F16F16F16_SS{}; } #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 112 == 0) { - return SM90_64x112x16_F16F16F16_RS{}; + return SM90::GMMA::SPARSE::GMMA_64x112x32_F16F16F16_SS{}; } #endif else if constexpr (Tile_N % 96 == 0) { - return SM90_64x96x16_F16F16F16_RS{}; + return SM90::GMMA::SPARSE::GMMA_64x96x32_F16F16F16_SS{}; } #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 80 == 0) { - return SM90_64x80x16_F16F16F16_RS{}; + return SM90::GMMA::SPARSE::GMMA_64x80x32_F16F16F16_SS{}; } #endif else if constexpr (Tile_N % 64 == 0) { - return SM90_64x64x16_F16F16F16_RS{}; + return SM90::GMMA::SPARSE::GMMA_64x64x32_F16F16F16_SS{}; } #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 48 == 0) { - return SM90_64x48x16_F16F16F16_RS{}; + return SM90::GMMA::SPARSE::GMMA_64x48x32_F16F16F16_SS{}; } #endif else if constexpr (Tile_N % 32 == 0) { - return SM90_64x32x16_F16F16F16_RS{}; + return SM90::GMMA::SPARSE::GMMA_64x32x32_F16F16F16_SS{}; } else if constexpr (Tile_N % 16 == 0) { - return SM90_64x16x16_F16F16F16_RS{}; + return SM90::GMMA::SPARSE::GMMA_64x16x32_F16F16F16_SS{}; } else if constexpr (Tile_N % 8 == 0) { - return SM90_64x8x16_F16F16F16_RS{}; + return SM90::GMMA::SPARSE::GMMA_64x8x32_F16F16F16_SS{}; } else { static_assert(Tile_N % 8 == 0, "Tile_N must be a multiple of 8."); @@ -1771,76 +1778,76 @@ rs_op_selector() else if constexpr (is_same_v && is_same_v) { static_assert(MajorA == GMMA::Major::K, "MajorA must be GMMA::Major::K for this config."); static_assert(MajorB == GMMA::Major::K, "MajorB must be GMMA::Major::K for this config."); - static_assert(size<2>(TileShape_MNK{}) % 32 == 0, "Tile_K must be a multiple of 32."); + static_assert(size<2>(TileShape_MNK{}) % 64 == 0, "Tile_K must be a multiple of 64."); if constexpr (Tile_N % 256 == 0) { - return SM90_64x256x32_F16E4M3E4M3_RS_TN{}; + return SM90::GMMA::SPARSE::GMMA_64x256x64_F16E4M3E4M3_SS_TN{}; } #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 240 == 0) { - return SM90_64x240x32_F16E4M3E4M3_RS_TN{}; + return SM90::GMMA::SPARSE::GMMA_64x240x64_F16E4M3E4M3_SS_TN{}; } #endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 224 == 0) { - return 
SM90_64x224x32_F16E4M3E4M3_RS_TN{}; + return SM90::GMMA::SPARSE::GMMA_64x224x64_F16E4M3E4M3_SS_TN{}; } #endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 208 == 0) { - return SM90_64x208x32_F16E4M3E4M3_RS_TN{}; + return SM90::GMMA::SPARSE::GMMA_64x208x64_F16E4M3E4M3_SS_TN{}; } #endif else if constexpr (Tile_N % 192 == 0) { - return SM90_64x192x32_F16E4M3E4M3_RS_TN{}; + return SM90::GMMA::SPARSE::GMMA_64x192x64_F16E4M3E4M3_SS_TN{}; } #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 176 == 0) { - return SM90_64x176x32_F16E4M3E4M3_RS_TN{}; + return SM90::GMMA::SPARSE::GMMA_64x176x64_F16E4M3E4M3_SS_TN{}; } #endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 160 == 0) { - return SM90_64x160x32_F16E4M3E4M3_RS_TN{}; + return SM90::GMMA::SPARSE::GMMA_64x160x64_F16E4M3E4M3_SS_TN{}; } #endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 144 == 0) { - return SM90_64x144x32_F16E4M3E4M3_RS_TN{}; + return SM90::GMMA::SPARSE::GMMA_64x144x64_F16E4M3E4M3_SS_TN{}; } #endif else if constexpr (Tile_N % 128 == 0) { - return SM90_64x128x32_F16E4M3E4M3_RS_TN{}; + return SM90::GMMA::SPARSE::GMMA_64x128x64_F16E4M3E4M3_SS_TN{}; } #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 112 == 0) { - return SM90_64x112x32_F16E4M3E4M3_RS_TN{}; + return SM90::GMMA::SPARSE::GMMA_64x112x64_F16E4M3E4M3_SS_TN{}; } #endif else if constexpr (Tile_N % 96 == 0) { - return SM90_64x96x32_F16E4M3E4M3_RS_TN{}; + return SM90::GMMA::SPARSE::GMMA_64x96x64_F16E4M3E4M3_SS_TN{}; } #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 80 == 0) { - return SM90_64x80x32_F16E4M3E4M3_RS_TN{}; + return SM90::GMMA::SPARSE::GMMA_64x80x64_F16E4M3E4M3_SS_TN{}; } #endif else if constexpr (Tile_N % 64 == 0) { - return SM90_64x64x32_F16E4M3E4M3_RS_TN{}; + return SM90::GMMA::SPARSE::GMMA_64x64x64_F16E4M3E4M3_SS_TN{}; } #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 48 == 0) { - return SM90_64x48x32_F16E4M3E4M3_RS_TN{}; + return SM90::GMMA::SPARSE::GMMA_64x48x64_F16E4M3E4M3_SS_TN{}; } #endif else if constexpr (Tile_N % 32 == 0) { - return SM90_64x32x32_F16E4M3E4M3_RS_TN{}; + return SM90::GMMA::SPARSE::GMMA_64x32x64_F16E4M3E4M3_SS_TN{}; } else if constexpr (Tile_N % 16 == 0) { - return SM90_64x16x32_F16E4M3E4M3_RS_TN{}; + return SM90::GMMA::SPARSE::GMMA_64x16x64_F16E4M3E4M3_SS_TN{}; } else if constexpr (Tile_N % 8 == 0) { - return SM90_64x8x32_F16E4M3E4M3_RS_TN{}; + return SM90::GMMA::SPARSE::GMMA_64x8x64_F16E4M3E4M3_SS_TN{}; } else { static_assert(Tile_N % 8 == 0, "Tile_N must be a multiple of 8."); @@ -1851,76 +1858,76 @@ rs_op_selector() else if constexpr (is_same_v && is_same_v) { static_assert(MajorA == GMMA::Major::K, "MajorA must be GMMA::Major::K for this config."); static_assert(MajorB == GMMA::Major::K, "MajorB must be GMMA::Major::K for this config."); - static_assert(size<2>(TileShape_MNK{}) % 32 == 0, "Tile_K must be a multiple of 32."); + static_assert(size<2>(TileShape_MNK{}) % 64 == 0, "Tile_K must be a multiple of 64."); if constexpr (Tile_N % 256 == 0) { - return SM90_64x256x32_F16E4M3E5M2_RS_TN{}; + return SM90::GMMA::SPARSE::GMMA_64x256x64_F16E4M3E5M2_SS_TN{}; } #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 240 == 0) { - return SM90_64x240x32_F16E4M3E5M2_RS_TN{}; + return SM90::GMMA::SPARSE::GMMA_64x240x64_F16E4M3E5M2_SS_TN{}; } #endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else 
if constexpr (Tile_N % 224 == 0) { - return SM90_64x224x32_F16E4M3E5M2_RS_TN{}; + return SM90::GMMA::SPARSE::GMMA_64x224x64_F16E4M3E5M2_SS_TN{}; } #endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 208 == 0) { - return SM90_64x208x32_F16E4M3E5M2_RS_TN{}; + return SM90::GMMA::SPARSE::GMMA_64x208x64_F16E4M3E5M2_SS_TN{}; } #endif else if constexpr (Tile_N % 192 == 0) { - return SM90_64x192x32_F16E4M3E5M2_RS_TN{}; + return SM90::GMMA::SPARSE::GMMA_64x192x64_F16E4M3E5M2_SS_TN{}; } #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 176 == 0) { - return SM90_64x176x32_F16E4M3E5M2_RS_TN{}; + return SM90::GMMA::SPARSE::GMMA_64x176x64_F16E4M3E5M2_SS_TN{}; } #endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 160 == 0) { - return SM90_64x160x32_F16E4M3E5M2_RS_TN{}; + return SM90::GMMA::SPARSE::GMMA_64x160x64_F16E4M3E5M2_SS_TN{}; } #endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 144 == 0) { - return SM90_64x144x32_F16E4M3E5M2_RS_TN{}; + return SM90::GMMA::SPARSE::GMMA_64x144x64_F16E4M3E5M2_SS_TN{}; } #endif else if constexpr (Tile_N % 128 == 0) { - return SM90_64x128x32_F16E4M3E5M2_RS_TN{}; + return SM90::GMMA::SPARSE::GMMA_64x128x64_F16E4M3E5M2_SS_TN{}; } #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 112 == 0) { - return SM90_64x112x32_F16E4M3E5M2_RS_TN{}; + return SM90::GMMA::SPARSE::GMMA_64x112x64_F16E4M3E5M2_SS_TN{}; } #endif else if constexpr (Tile_N % 96 == 0) { - return SM90_64x96x32_F16E4M3E5M2_RS_TN{}; + return SM90::GMMA::SPARSE::GMMA_64x96x64_F16E4M3E5M2_SS_TN{}; } #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 80 == 0) { - return SM90_64x80x32_F16E4M3E5M2_RS_TN{}; + return SM90::GMMA::SPARSE::GMMA_64x80x64_F16E4M3E5M2_SS_TN{}; } #endif else if constexpr (Tile_N % 64 == 0) { - return SM90_64x64x32_F16E4M3E5M2_RS_TN{}; + return SM90::GMMA::SPARSE::GMMA_64x64x64_F16E4M3E5M2_SS_TN{}; } #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 48 == 0) { - return SM90_64x48x32_F16E4M3E5M2_RS_TN{}; + return SM90::GMMA::SPARSE::GMMA_64x48x64_F16E4M3E5M2_SS_TN{}; } #endif else if constexpr (Tile_N % 32 == 0) { - return SM90_64x32x32_F16E4M3E5M2_RS_TN{}; + return SM90::GMMA::SPARSE::GMMA_64x32x64_F16E4M3E5M2_SS_TN{}; } else if constexpr (Tile_N % 16 == 0) { - return SM90_64x16x32_F16E4M3E5M2_RS_TN{}; + return SM90::GMMA::SPARSE::GMMA_64x16x64_F16E4M3E5M2_SS_TN{}; } else if constexpr (Tile_N % 8 == 0) { - return SM90_64x8x32_F16E4M3E5M2_RS_TN{}; + return SM90::GMMA::SPARSE::GMMA_64x8x64_F16E4M3E5M2_SS_TN{}; } else { static_assert(Tile_N % 8 == 0, "Tile_N must be a multiple of 8."); @@ -1931,76 +1938,76 @@ rs_op_selector() else if constexpr (is_same_v && is_same_v) { static_assert(MajorA == GMMA::Major::K, "MajorA must be GMMA::Major::K for this config."); static_assert(MajorB == GMMA::Major::K, "MajorB must be GMMA::Major::K for this config."); - static_assert(size<2>(TileShape_MNK{}) % 32 == 0, "Tile_K must be a multiple of 32."); + static_assert(size<2>(TileShape_MNK{}) % 64 == 0, "Tile_K must be a multiple of 64."); if constexpr (Tile_N % 256 == 0) { - return SM90_64x256x32_F16E5M2E4M3_RS_TN{}; + return SM90::GMMA::SPARSE::GMMA_64x256x64_F16E5M2E4M3_SS_TN{}; } #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 240 == 0) { - return SM90_64x240x32_F16E5M2E4M3_RS_TN{}; + return SM90::GMMA::SPARSE::GMMA_64x240x64_F16E5M2E4M3_SS_TN{}; } #endif #if 
defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 224 == 0) { - return SM90_64x224x32_F16E5M2E4M3_RS_TN{}; + return SM90::GMMA::SPARSE::GMMA_64x224x64_F16E5M2E4M3_SS_TN{}; } #endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 208 == 0) { - return SM90_64x208x32_F16E5M2E4M3_RS_TN{}; + return SM90::GMMA::SPARSE::GMMA_64x208x64_F16E5M2E4M3_SS_TN{}; } #endif else if constexpr (Tile_N % 192 == 0) { - return SM90_64x192x32_F16E5M2E4M3_RS_TN{}; + return SM90::GMMA::SPARSE::GMMA_64x192x64_F16E5M2E4M3_SS_TN{}; } #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 176 == 0) { - return SM90_64x176x32_F16E5M2E4M3_RS_TN{}; + return SM90::GMMA::SPARSE::GMMA_64x176x64_F16E5M2E4M3_SS_TN{}; } #endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 160 == 0) { - return SM90_64x160x32_F16E5M2E4M3_RS_TN{}; + return SM90::GMMA::SPARSE::GMMA_64x160x64_F16E5M2E4M3_SS_TN{}; } #endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 144 == 0) { - return SM90_64x144x32_F16E5M2E4M3_RS_TN{}; + return SM90::GMMA::SPARSE::GMMA_64x144x64_F16E5M2E4M3_SS_TN{}; } #endif else if constexpr (Tile_N % 128 == 0) { - return SM90_64x128x32_F16E5M2E4M3_RS_TN{}; + return SM90::GMMA::SPARSE::GMMA_64x128x64_F16E5M2E4M3_SS_TN{}; } #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 112 == 0) { - return SM90_64x112x32_F16E5M2E4M3_RS_TN{}; + return SM90::GMMA::SPARSE::GMMA_64x112x64_F16E5M2E4M3_SS_TN{}; } #endif else if constexpr (Tile_N % 96 == 0) { - return SM90_64x96x32_F16E5M2E4M3_RS_TN{}; + return SM90::GMMA::SPARSE::GMMA_64x96x64_F16E5M2E4M3_SS_TN{}; } #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 80 == 0) { - return SM90_64x80x32_F16E5M2E4M3_RS_TN{}; + return SM90::GMMA::SPARSE::GMMA_64x80x64_F16E5M2E4M3_SS_TN{}; } #endif else if constexpr (Tile_N % 64 == 0) { - return SM90_64x64x32_F16E5M2E4M3_RS_TN{}; + return SM90::GMMA::SPARSE::GMMA_64x64x64_F16E5M2E4M3_SS_TN{}; } #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 48 == 0) { - return SM90_64x48x32_F16E5M2E4M3_RS_TN{}; + return SM90::GMMA::SPARSE::GMMA_64x48x64_F16E5M2E4M3_SS_TN{}; } #endif else if constexpr (Tile_N % 32 == 0) { - return SM90_64x32x32_F16E5M2E4M3_RS_TN{}; + return SM90::GMMA::SPARSE::GMMA_64x32x64_F16E5M2E4M3_SS_TN{}; } else if constexpr (Tile_N % 16 == 0) { - return SM90_64x16x32_F16E5M2E4M3_RS_TN{}; + return SM90::GMMA::SPARSE::GMMA_64x16x64_F16E5M2E4M3_SS_TN{}; } else if constexpr (Tile_N % 8 == 0) { - return SM90_64x8x32_F16E5M2E4M3_RS_TN{}; + return SM90::GMMA::SPARSE::GMMA_64x8x64_F16E5M2E4M3_SS_TN{}; } else { static_assert(Tile_N % 8 == 0, "Tile_N must be a multiple of 8."); @@ -2011,76 +2018,76 @@ rs_op_selector() else if constexpr (is_same_v && is_same_v) { static_assert(MajorA == GMMA::Major::K, "MajorA must be GMMA::Major::K for this config."); static_assert(MajorB == GMMA::Major::K, "MajorB must be GMMA::Major::K for this config."); - static_assert(size<2>(TileShape_MNK{}) % 32 == 0, "Tile_K must be a multiple of 32."); + static_assert(size<2>(TileShape_MNK{}) % 64 == 0, "Tile_K must be a multiple of 64."); if constexpr (Tile_N % 256 == 0) { - return SM90_64x256x32_F16E5M2E5M2_RS_TN{}; + return SM90::GMMA::SPARSE::GMMA_64x256x64_F16E5M2E5M2_SS_TN{}; } #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 240 == 0) { - return SM90_64x240x32_F16E5M2E5M2_RS_TN{}; + return 
SM90::GMMA::SPARSE::GMMA_64x240x64_F16E5M2E5M2_SS_TN{}; } #endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 224 == 0) { - return SM90_64x224x32_F16E5M2E5M2_RS_TN{}; + return SM90::GMMA::SPARSE::GMMA_64x224x64_F16E5M2E5M2_SS_TN{}; } #endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 208 == 0) { - return SM90_64x208x32_F16E5M2E5M2_RS_TN{}; + return SM90::GMMA::SPARSE::GMMA_64x208x64_F16E5M2E5M2_SS_TN{}; } #endif else if constexpr (Tile_N % 192 == 0) { - return SM90_64x192x32_F16E5M2E5M2_RS_TN{}; + return SM90::GMMA::SPARSE::GMMA_64x192x64_F16E5M2E5M2_SS_TN{}; } #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 176 == 0) { - return SM90_64x176x32_F16E5M2E5M2_RS_TN{}; + return SM90::GMMA::SPARSE::GMMA_64x176x64_F16E5M2E5M2_SS_TN{}; } #endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 160 == 0) { - return SM90_64x160x32_F16E5M2E5M2_RS_TN{}; + return SM90::GMMA::SPARSE::GMMA_64x160x64_F16E5M2E5M2_SS_TN{}; } #endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 144 == 0) { - return SM90_64x144x32_F16E5M2E5M2_RS_TN{}; + return SM90::GMMA::SPARSE::GMMA_64x144x64_F16E5M2E5M2_SS_TN{}; } #endif else if constexpr (Tile_N % 128 == 0) { - return SM90_64x128x32_F16E5M2E5M2_RS_TN{}; + return SM90::GMMA::SPARSE::GMMA_64x128x64_F16E5M2E5M2_SS_TN{}; } #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 112 == 0) { - return SM90_64x112x32_F16E5M2E5M2_RS_TN{}; + return SM90::GMMA::SPARSE::GMMA_64x112x64_F16E5M2E5M2_SS_TN{}; } #endif else if constexpr (Tile_N % 96 == 0) { - return SM90_64x96x32_F16E5M2E5M2_RS_TN{}; + return SM90::GMMA::SPARSE::GMMA_64x96x64_F16E5M2E5M2_SS_TN{}; } #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 80 == 0) { - return SM90_64x80x32_F16E5M2E5M2_RS_TN{}; + return SM90::GMMA::SPARSE::GMMA_64x80x64_F16E5M2E5M2_SS_TN{}; } #endif else if constexpr (Tile_N % 64 == 0) { - return SM90_64x64x32_F16E5M2E5M2_RS_TN{}; + return SM90::GMMA::SPARSE::GMMA_64x64x64_F16E5M2E5M2_SS_TN{}; } #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 48 == 0) { - return SM90_64x48x32_F16E5M2E5M2_RS_TN{}; + return SM90::GMMA::SPARSE::GMMA_64x48x64_F16E5M2E5M2_SS_TN{}; } #endif else if constexpr (Tile_N % 32 == 0) { - return SM90_64x32x32_F16E5M2E5M2_RS_TN{}; + return SM90::GMMA::SPARSE::GMMA_64x32x64_F16E5M2E5M2_SS_TN{}; } else if constexpr (Tile_N % 16 == 0) { - return SM90_64x16x32_F16E5M2E5M2_RS_TN{}; + return SM90::GMMA::SPARSE::GMMA_64x16x64_F16E5M2E5M2_SS_TN{}; } else if constexpr (Tile_N % 8 == 0) { - return SM90_64x8x32_F16E5M2E5M2_RS_TN{}; + return SM90::GMMA::SPARSE::GMMA_64x8x64_F16E5M2E5M2_SS_TN{}; } else { static_assert(Tile_N % 8 == 0, "Tile_N must be a multiple of 8."); @@ -2097,76 +2104,76 @@ rs_op_selector() // Input A: half_t ; Input B: half_t if constexpr (is_same_v && is_same_v) { - static_assert(size<2>(TileShape_MNK{}) % 16 == 0, "Tile_K must be a multiple of 16."); + static_assert(size<2>(TileShape_MNK{}) % 32 == 0, "Tile_K must be a multiple of 32."); if constexpr (Tile_N % 256 == 0) { - return SM90_64x256x16_F32F16F16_RS{}; + return SM90::GMMA::SPARSE::GMMA_64x256x32_F32F16F16_SS{}; } #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 240 == 0) { - return SM90_64x240x16_F32F16F16_RS{}; + return SM90::GMMA::SPARSE::GMMA_64x240x32_F32F16F16_SS{}; } #endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if 
constexpr (Tile_N % 224 == 0) { - return SM90_64x224x16_F32F16F16_RS{}; + return SM90::GMMA::SPARSE::GMMA_64x224x32_F32F16F16_SS{}; } #endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 208 == 0) { - return SM90_64x208x16_F32F16F16_RS{}; + return SM90::GMMA::SPARSE::GMMA_64x208x32_F32F16F16_SS{}; } #endif else if constexpr (Tile_N % 192 == 0) { - return SM90_64x192x16_F32F16F16_RS{}; + return SM90::GMMA::SPARSE::GMMA_64x192x32_F32F16F16_SS{}; } #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 176 == 0) { - return SM90_64x176x16_F32F16F16_RS{}; + return SM90::GMMA::SPARSE::GMMA_64x176x32_F32F16F16_SS{}; } #endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 160 == 0) { - return SM90_64x160x16_F32F16F16_RS{}; + return SM90::GMMA::SPARSE::GMMA_64x160x32_F32F16F16_SS{}; } #endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 144 == 0) { - return SM90_64x144x16_F32F16F16_RS{}; + return SM90::GMMA::SPARSE::GMMA_64x144x32_F32F16F16_SS{}; } #endif else if constexpr (Tile_N % 128 == 0) { - return SM90_64x128x16_F32F16F16_RS{}; + return SM90::GMMA::SPARSE::GMMA_64x128x32_F32F16F16_SS{}; } #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 112 == 0) { - return SM90_64x112x16_F32F16F16_RS{}; + return SM90::GMMA::SPARSE::GMMA_64x112x32_F32F16F16_SS{}; } #endif else if constexpr (Tile_N % 96 == 0) { - return SM90_64x96x16_F32F16F16_RS{}; + return SM90::GMMA::SPARSE::GMMA_64x96x32_F32F16F16_SS{}; } #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 80 == 0) { - return SM90_64x80x16_F32F16F16_RS{}; + return SM90::GMMA::SPARSE::GMMA_64x80x32_F32F16F16_SS{}; } #endif else if constexpr (Tile_N % 64 == 0) { - return SM90_64x64x16_F32F16F16_RS{}; + return SM90::GMMA::SPARSE::GMMA_64x64x32_F32F16F16_SS{}; } #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 48 == 0) { - return SM90_64x48x16_F32F16F16_RS{}; + return SM90::GMMA::SPARSE::GMMA_64x48x32_F32F16F16_SS{}; } #endif else if constexpr (Tile_N % 32 == 0) { - return SM90_64x32x16_F32F16F16_RS{}; + return SM90::GMMA::SPARSE::GMMA_64x32x32_F32F16F16_SS{}; } else if constexpr (Tile_N % 16 == 0) { - return SM90_64x16x16_F32F16F16_RS{}; + return SM90::GMMA::SPARSE::GMMA_64x16x32_F32F16F16_SS{}; } else if constexpr (Tile_N % 8 == 0) { - return SM90_64x8x16_F32F16F16_RS{}; + return SM90::GMMA::SPARSE::GMMA_64x8x32_F32F16F16_SS{}; } else { static_assert(Tile_N % 8 == 0, "Tile_N must be a multiple of 8."); @@ -2175,76 +2182,76 @@ rs_op_selector() // Input A: bfloat16_t ; Input B: bfloat16_t else if constexpr (is_same_v && is_same_v) { - static_assert(size<2>(TileShape_MNK{}) % 16 == 0, "Tile_K must be a multiple of 16."); + static_assert(size<2>(TileShape_MNK{}) % 32 == 0, "Tile_K must be a multiple of 32."); if constexpr (Tile_N % 256 == 0) { - return SM90_64x256x16_F32BF16BF16_RS{}; + return SM90::GMMA::SPARSE::GMMA_64x256x32_F32BF16BF16_SS{}; } #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 240 == 0) { - return SM90_64x240x16_F32BF16BF16_RS{}; + return SM90::GMMA::SPARSE::GMMA_64x240x32_F32BF16BF16_SS{}; } #endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 224 == 0) { - return SM90_64x224x16_F32BF16BF16_RS{}; + return SM90::GMMA::SPARSE::GMMA_64x224x32_F32BF16BF16_SS{}; } #endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 208 == 0) { - return 
SM90_64x208x16_F32BF16BF16_RS{}; + return SM90::GMMA::SPARSE::GMMA_64x208x32_F32BF16BF16_SS{}; } #endif else if constexpr (Tile_N % 192 == 0) { - return SM90_64x192x16_F32BF16BF16_RS{}; + return SM90::GMMA::SPARSE::GMMA_64x192x32_F32BF16BF16_SS{}; } #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 176 == 0) { - return SM90_64x176x16_F32BF16BF16_RS{}; + return SM90::GMMA::SPARSE::GMMA_64x176x32_F32BF16BF16_SS{}; } #endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 160 == 0) { - return SM90_64x160x16_F32BF16BF16_RS{}; + return SM90::GMMA::SPARSE::GMMA_64x160x32_F32BF16BF16_SS{}; } #endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 144 == 0) { - return SM90_64x144x16_F32BF16BF16_RS{}; + return SM90::GMMA::SPARSE::GMMA_64x144x32_F32BF16BF16_SS{}; } #endif else if constexpr (Tile_N % 128 == 0) { - return SM90_64x128x16_F32BF16BF16_RS{}; + return SM90::GMMA::SPARSE::GMMA_64x128x32_F32BF16BF16_SS{}; } #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 112 == 0) { - return SM90_64x112x16_F32BF16BF16_RS{}; + return SM90::GMMA::SPARSE::GMMA_64x112x32_F32BF16BF16_SS{}; } #endif else if constexpr (Tile_N % 96 == 0) { - return SM90_64x96x16_F32BF16BF16_RS{}; + return SM90::GMMA::SPARSE::GMMA_64x96x32_F32BF16BF16_SS{}; } #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 80 == 0) { - return SM90_64x80x16_F32BF16BF16_RS{}; + return SM90::GMMA::SPARSE::GMMA_64x80x32_F32BF16BF16_SS{}; } #endif else if constexpr (Tile_N % 64 == 0) { - return SM90_64x64x16_F32BF16BF16_RS{}; + return SM90::GMMA::SPARSE::GMMA_64x64x32_F32BF16BF16_SS{}; } #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 48 == 0) { - return SM90_64x48x16_F32BF16BF16_RS{}; + return SM90::GMMA::SPARSE::GMMA_64x48x32_F32BF16BF16_SS{}; } #endif else if constexpr (Tile_N % 32 == 0) { - return SM90_64x32x16_F32BF16BF16_RS{}; + return SM90::GMMA::SPARSE::GMMA_64x32x32_F32BF16BF16_SS{}; } else if constexpr (Tile_N % 16 == 0) { - return SM90_64x16x16_F32BF16BF16_RS{}; + return SM90::GMMA::SPARSE::GMMA_64x16x32_F32BF16BF16_SS{}; } else if constexpr (Tile_N % 8 == 0) { - return SM90_64x8x16_F32BF16BF16_RS{}; + return SM90::GMMA::SPARSE::GMMA_64x8x32_F32BF16BF16_SS{}; } else { static_assert(Tile_N % 8 == 0, "Tile_N must be a multiple of 8."); @@ -2255,76 +2262,76 @@ rs_op_selector() else if constexpr (is_same_v && is_same_v) { static_assert(MajorA == GMMA::Major::K, "MajorA must be GMMA::Major::K for this config."); static_assert(MajorB == GMMA::Major::K, "MajorB must be GMMA::Major::K for this config."); - static_assert(size<2>(TileShape_MNK{}) % 8 == 0, "Tile_K must be a multiple of 8."); + static_assert(size<2>(TileShape_MNK{}) % 16 == 0, "Tile_K must be a multiple of 16."); if constexpr (Tile_N % 256 == 0) { - return SM90_64x256x8_F32TF32TF32_RS_TN{}; + return SM90::GMMA::SPARSE::GMMA_64x256x16_F32TF32TF32_SS_TN{}; } #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 240 == 0) { - return SM90_64x240x8_F32TF32TF32_RS_TN{}; + return SM90::GMMA::SPARSE::GMMA_64x240x16_F32TF32TF32_SS_TN{}; } #endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 224 == 0) { - return SM90_64x224x8_F32TF32TF32_RS_TN{}; + return SM90::GMMA::SPARSE::GMMA_64x224x16_F32TF32TF32_SS_TN{}; } #endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 208 == 0) { - return SM90_64x208x8_F32TF32TF32_RS_TN{}; + return 
SM90::GMMA::SPARSE::GMMA_64x208x16_F32TF32TF32_SS_TN{}; } #endif else if constexpr (Tile_N % 192 == 0) { - return SM90_64x192x8_F32TF32TF32_RS_TN{}; + return SM90::GMMA::SPARSE::GMMA_64x192x16_F32TF32TF32_SS_TN{}; } #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 176 == 0) { - return SM90_64x176x8_F32TF32TF32_RS_TN{}; + return SM90::GMMA::SPARSE::GMMA_64x176x16_F32TF32TF32_SS_TN{}; } #endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 160 == 0) { - return SM90_64x160x8_F32TF32TF32_RS_TN{}; + return SM90::GMMA::SPARSE::GMMA_64x160x16_F32TF32TF32_SS_TN{}; } #endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 144 == 0) { - return SM90_64x144x8_F32TF32TF32_RS_TN{}; + return SM90::GMMA::SPARSE::GMMA_64x144x16_F32TF32TF32_SS_TN{}; } #endif else if constexpr (Tile_N % 128 == 0) { - return SM90_64x128x8_F32TF32TF32_RS_TN{}; + return SM90::GMMA::SPARSE::GMMA_64x128x16_F32TF32TF32_SS_TN{}; } #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 112 == 0) { - return SM90_64x112x8_F32TF32TF32_RS_TN{}; + return SM90::GMMA::SPARSE::GMMA_64x112x16_F32TF32TF32_SS_TN{}; } #endif else if constexpr (Tile_N % 96 == 0) { - return SM90_64x96x8_F32TF32TF32_RS_TN{}; + return SM90::GMMA::SPARSE::GMMA_64x96x16_F32TF32TF32_SS_TN{}; } #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 80 == 0) { - return SM90_64x80x8_F32TF32TF32_RS_TN{}; + return SM90::GMMA::SPARSE::GMMA_64x80x16_F32TF32TF32_SS_TN{}; } #endif else if constexpr (Tile_N % 64 == 0) { - return SM90_64x64x8_F32TF32TF32_RS_TN{}; + return SM90::GMMA::SPARSE::GMMA_64x64x16_F32TF32TF32_SS_TN{}; } #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 48 == 0) { - return SM90_64x48x8_F32TF32TF32_RS_TN{}; + return SM90::GMMA::SPARSE::GMMA_64x48x16_F32TF32TF32_SS_TN{}; } #endif else if constexpr (Tile_N % 32 == 0) { - return SM90_64x32x8_F32TF32TF32_RS_TN{}; + return SM90::GMMA::SPARSE::GMMA_64x32x16_F32TF32TF32_SS_TN{}; } else if constexpr (Tile_N % 16 == 0) { - return SM90_64x16x8_F32TF32TF32_RS_TN{}; + return SM90::GMMA::SPARSE::GMMA_64x16x16_F32TF32TF32_SS_TN{}; } else if constexpr (Tile_N % 8 == 0) { - return SM90_64x8x8_F32TF32TF32_RS_TN{}; + return SM90::GMMA::SPARSE::GMMA_64x8x16_F32TF32TF32_SS_TN{}; } else { static_assert(Tile_N % 8 == 0, "Tile_N must be a multiple of 8."); @@ -2335,76 +2342,76 @@ rs_op_selector() else if constexpr (is_same_v && is_same_v) { static_assert(MajorA == GMMA::Major::K, "MajorA must be GMMA::Major::K for this config."); static_assert(MajorB == GMMA::Major::K, "MajorB must be GMMA::Major::K for this config."); - static_assert(size<2>(TileShape_MNK{}) % 32 == 0, "Tile_K must be a multiple of 32."); + static_assert(size<2>(TileShape_MNK{}) % 64 == 0, "Tile_K must be a multiple of 64."); if constexpr (Tile_N % 256 == 0) { - return SM90_64x256x32_F32E4M3E4M3_RS_TN{}; + return SM90::GMMA::SPARSE::GMMA_64x256x64_F32E4M3E4M3_SS_TN{}; } #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 240 == 0) { - return SM90_64x240x32_F32E4M3E4M3_RS_TN{}; + return SM90::GMMA::SPARSE::GMMA_64x240x64_F32E4M3E4M3_SS_TN{}; } #endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 224 == 0) { - return SM90_64x224x32_F32E4M3E4M3_RS_TN{}; + return SM90::GMMA::SPARSE::GMMA_64x224x64_F32E4M3E4M3_SS_TN{}; } #endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 208 == 0) { - return 
SM90_64x208x32_F32E4M3E4M3_RS_TN{}; + return SM90::GMMA::SPARSE::GMMA_64x208x64_F32E4M3E4M3_SS_TN{}; } #endif else if constexpr (Tile_N % 192 == 0) { - return SM90_64x192x32_F32E4M3E4M3_RS_TN{}; + return SM90::GMMA::SPARSE::GMMA_64x192x64_F32E4M3E4M3_SS_TN{}; } #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 176 == 0) { - return SM90_64x176x32_F32E4M3E4M3_RS_TN{}; + return SM90::GMMA::SPARSE::GMMA_64x176x64_F32E4M3E4M3_SS_TN{}; } #endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 160 == 0) { - return SM90_64x160x32_F32E4M3E4M3_RS_TN{}; + return SM90::GMMA::SPARSE::GMMA_64x160x64_F32E4M3E4M3_SS_TN{}; } #endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 144 == 0) { - return SM90_64x144x32_F32E4M3E4M3_RS_TN{}; + return SM90::GMMA::SPARSE::GMMA_64x144x64_F32E4M3E4M3_SS_TN{}; } #endif else if constexpr (Tile_N % 128 == 0) { - return SM90_64x128x32_F32E4M3E4M3_RS_TN{}; + return SM90::GMMA::SPARSE::GMMA_64x128x64_F32E4M3E4M3_SS_TN{}; } #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 112 == 0) { - return SM90_64x112x32_F32E4M3E4M3_RS_TN{}; + return SM90::GMMA::SPARSE::GMMA_64x112x64_F32E4M3E4M3_SS_TN{}; } #endif else if constexpr (Tile_N % 96 == 0) { - return SM90_64x96x32_F32E4M3E4M3_RS_TN{}; + return SM90::GMMA::SPARSE::GMMA_64x96x64_F32E4M3E4M3_SS_TN{}; } #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 80 == 0) { - return SM90_64x80x32_F32E4M3E4M3_RS_TN{}; + return SM90::GMMA::SPARSE::GMMA_64x80x64_F32E4M3E4M3_SS_TN{}; } #endif else if constexpr (Tile_N % 64 == 0) { - return SM90_64x64x32_F32E4M3E4M3_RS_TN{}; + return SM90::GMMA::SPARSE::GMMA_64x64x64_F32E4M3E4M3_SS_TN{}; } #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 48 == 0) { - return SM90_64x48x32_F32E4M3E4M3_RS_TN{}; + return SM90::GMMA::SPARSE::GMMA_64x48x64_F32E4M3E4M3_SS_TN{}; } #endif else if constexpr (Tile_N % 32 == 0) { - return SM90_64x32x32_F32E4M3E4M3_RS_TN{}; + return SM90::GMMA::SPARSE::GMMA_64x32x64_F32E4M3E4M3_SS_TN{}; } else if constexpr (Tile_N % 16 == 0) { - return SM90_64x16x32_F32E4M3E4M3_RS_TN{}; + return SM90::GMMA::SPARSE::GMMA_64x16x64_F32E4M3E4M3_SS_TN{}; } else if constexpr (Tile_N % 8 == 0) { - return SM90_64x8x32_F32E4M3E4M3_RS_TN{}; + return SM90::GMMA::SPARSE::GMMA_64x8x64_F32E4M3E4M3_SS_TN{}; } else { static_assert(Tile_N % 8 == 0, "Tile_N must be a multiple of 8."); @@ -2415,76 +2422,76 @@ rs_op_selector() else if constexpr (is_same_v && is_same_v) { static_assert(MajorA == GMMA::Major::K, "MajorA must be GMMA::Major::K for this config."); static_assert(MajorB == GMMA::Major::K, "MajorB must be GMMA::Major::K for this config."); - static_assert(size<2>(TileShape_MNK{}) % 32 == 0, "Tile_K must be a multiple of 32."); + static_assert(size<2>(TileShape_MNK{}) % 64 == 0, "Tile_K must be a multiple of 64."); if constexpr (Tile_N % 256 == 0) { - return SM90_64x256x32_F32E4M3E5M2_RS_TN{}; + return SM90::GMMA::SPARSE::GMMA_64x256x64_F32E4M3E5M2_SS_TN{}; } #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 240 == 0) { - return SM90_64x240x32_F32E4M3E5M2_RS_TN{}; + return SM90::GMMA::SPARSE::GMMA_64x240x64_F32E4M3E5M2_SS_TN{}; } #endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 224 == 0) { - return SM90_64x224x32_F32E4M3E5M2_RS_TN{}; + return SM90::GMMA::SPARSE::GMMA_64x224x64_F32E4M3E5M2_SS_TN{}; } #endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else 
if constexpr (Tile_N % 208 == 0) { - return SM90_64x208x32_F32E4M3E5M2_RS_TN{}; + return SM90::GMMA::SPARSE::GMMA_64x208x64_F32E4M3E5M2_SS_TN{}; } #endif else if constexpr (Tile_N % 192 == 0) { - return SM90_64x192x32_F32E4M3E5M2_RS_TN{}; + return SM90::GMMA::SPARSE::GMMA_64x192x64_F32E4M3E5M2_SS_TN{}; } #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 176 == 0) { - return SM90_64x176x32_F32E4M3E5M2_RS_TN{}; + return SM90::GMMA::SPARSE::GMMA_64x176x64_F32E4M3E5M2_SS_TN{}; } #endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 160 == 0) { - return SM90_64x160x32_F32E4M3E5M2_RS_TN{}; + return SM90::GMMA::SPARSE::GMMA_64x160x64_F32E4M3E5M2_SS_TN{}; } #endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 144 == 0) { - return SM90_64x144x32_F32E4M3E5M2_RS_TN{}; + return SM90::GMMA::SPARSE::GMMA_64x144x64_F32E4M3E5M2_SS_TN{}; } #endif else if constexpr (Tile_N % 128 == 0) { - return SM90_64x128x32_F32E4M3E5M2_RS_TN{}; + return SM90::GMMA::SPARSE::GMMA_64x128x64_F32E4M3E5M2_SS_TN{}; } #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 112 == 0) { - return SM90_64x112x32_F32E4M3E5M2_RS_TN{}; + return SM90::GMMA::SPARSE::GMMA_64x112x64_F32E4M3E5M2_SS_TN{}; } #endif else if constexpr (Tile_N % 96 == 0) { - return SM90_64x96x32_F32E4M3E5M2_RS_TN{}; + return SM90::GMMA::SPARSE::GMMA_64x96x64_F32E4M3E5M2_SS_TN{}; } #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 80 == 0) { - return SM90_64x80x32_F32E4M3E5M2_RS_TN{}; + return SM90::GMMA::SPARSE::GMMA_64x80x64_F32E4M3E5M2_SS_TN{}; } #endif else if constexpr (Tile_N % 64 == 0) { - return SM90_64x64x32_F32E4M3E5M2_RS_TN{}; + return SM90::GMMA::SPARSE::GMMA_64x64x64_F32E4M3E5M2_SS_TN{}; } #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 48 == 0) { - return SM90_64x48x32_F32E4M3E5M2_RS_TN{}; + return SM90::GMMA::SPARSE::GMMA_64x48x64_F32E4M3E5M2_SS_TN{}; } #endif else if constexpr (Tile_N % 32 == 0) { - return SM90_64x32x32_F32E4M3E5M2_RS_TN{}; + return SM90::GMMA::SPARSE::GMMA_64x32x64_F32E4M3E5M2_SS_TN{}; } else if constexpr (Tile_N % 16 == 0) { - return SM90_64x16x32_F32E4M3E5M2_RS_TN{}; + return SM90::GMMA::SPARSE::GMMA_64x16x64_F32E4M3E5M2_SS_TN{}; } else if constexpr (Tile_N % 8 == 0) { - return SM90_64x8x32_F32E4M3E5M2_RS_TN{}; + return SM90::GMMA::SPARSE::GMMA_64x8x64_F32E4M3E5M2_SS_TN{}; } else { static_assert(Tile_N % 8 == 0, "Tile_N must be a multiple of 8."); @@ -2495,76 +2502,76 @@ rs_op_selector() else if constexpr (is_same_v && is_same_v) { static_assert(MajorA == GMMA::Major::K, "MajorA must be GMMA::Major::K for this config."); static_assert(MajorB == GMMA::Major::K, "MajorB must be GMMA::Major::K for this config."); - static_assert(size<2>(TileShape_MNK{}) % 32 == 0, "Tile_K must be a multiple of 32."); + static_assert(size<2>(TileShape_MNK{}) % 64 == 0, "Tile_K must be a multiple of 64."); if constexpr (Tile_N % 256 == 0) { - return SM90_64x256x32_F32E5M2E4M3_RS_TN{}; + return SM90::GMMA::SPARSE::GMMA_64x256x64_F32E5M2E4M3_SS_TN{}; } #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 240 == 0) { - return SM90_64x240x32_F32E5M2E4M3_RS_TN{}; + return SM90::GMMA::SPARSE::GMMA_64x240x64_F32E5M2E4M3_SS_TN{}; } #endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 224 == 0) { - return SM90_64x224x32_F32E5M2E4M3_RS_TN{}; + return SM90::GMMA::SPARSE::GMMA_64x224x64_F32E5M2E4M3_SS_TN{}; } #endif #if 
defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 208 == 0) { - return SM90_64x208x32_F32E5M2E4M3_RS_TN{}; + return SM90::GMMA::SPARSE::GMMA_64x208x64_F32E5M2E4M3_SS_TN{}; } #endif else if constexpr (Tile_N % 192 == 0) { - return SM90_64x192x32_F32E5M2E4M3_RS_TN{}; + return SM90::GMMA::SPARSE::GMMA_64x192x64_F32E5M2E4M3_SS_TN{}; } #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 176 == 0) { - return SM90_64x176x32_F32E5M2E4M3_RS_TN{}; + return SM90::GMMA::SPARSE::GMMA_64x176x64_F32E5M2E4M3_SS_TN{}; } #endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 160 == 0) { - return SM90_64x160x32_F32E5M2E4M3_RS_TN{}; + return SM90::GMMA::SPARSE::GMMA_64x160x64_F32E5M2E4M3_SS_TN{}; } #endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 144 == 0) { - return SM90_64x144x32_F32E5M2E4M3_RS_TN{}; + return SM90::GMMA::SPARSE::GMMA_64x144x64_F32E5M2E4M3_SS_TN{}; } #endif else if constexpr (Tile_N % 128 == 0) { - return SM90_64x128x32_F32E5M2E4M3_RS_TN{}; + return SM90::GMMA::SPARSE::GMMA_64x128x64_F32E5M2E4M3_SS_TN{}; } #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 112 == 0) { - return SM90_64x112x32_F32E5M2E4M3_RS_TN{}; + return SM90::GMMA::SPARSE::GMMA_64x112x64_F32E5M2E4M3_SS_TN{}; } #endif else if constexpr (Tile_N % 96 == 0) { - return SM90_64x96x32_F32E5M2E4M3_RS_TN{}; + return SM90::GMMA::SPARSE::GMMA_64x96x64_F32E5M2E4M3_SS_TN{}; } #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 80 == 0) { - return SM90_64x80x32_F32E5M2E4M3_RS_TN{}; + return SM90::GMMA::SPARSE::GMMA_64x80x64_F32E5M2E4M3_SS_TN{}; } #endif else if constexpr (Tile_N % 64 == 0) { - return SM90_64x64x32_F32E5M2E4M3_RS_TN{}; + return SM90::GMMA::SPARSE::GMMA_64x64x64_F32E5M2E4M3_SS_TN{}; } #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 48 == 0) { - return SM90_64x48x32_F32E5M2E4M3_RS_TN{}; + return SM90::GMMA::SPARSE::GMMA_64x48x64_F32E5M2E4M3_SS_TN{}; } #endif else if constexpr (Tile_N % 32 == 0) { - return SM90_64x32x32_F32E5M2E4M3_RS_TN{}; + return SM90::GMMA::SPARSE::GMMA_64x32x64_F32E5M2E4M3_SS_TN{}; } else if constexpr (Tile_N % 16 == 0) { - return SM90_64x16x32_F32E5M2E4M3_RS_TN{}; + return SM90::GMMA::SPARSE::GMMA_64x16x64_F32E5M2E4M3_SS_TN{}; } else if constexpr (Tile_N % 8 == 0) { - return SM90_64x8x32_F32E5M2E4M3_RS_TN{}; + return SM90::GMMA::SPARSE::GMMA_64x8x64_F32E5M2E4M3_SS_TN{}; } else { static_assert(Tile_N % 8 == 0, "Tile_N must be a multiple of 8."); @@ -2575,76 +2582,76 @@ rs_op_selector() else if constexpr (is_same_v && is_same_v) { static_assert(MajorA == GMMA::Major::K, "MajorA must be GMMA::Major::K for this config."); static_assert(MajorB == GMMA::Major::K, "MajorB must be GMMA::Major::K for this config."); - static_assert(size<2>(TileShape_MNK{}) % 32 == 0, "Tile_K must be a multiple of 32."); + static_assert(size<2>(TileShape_MNK{}) % 64 == 0, "Tile_K must be a multiple of 64."); if constexpr (Tile_N % 256 == 0) { - return SM90_64x256x32_F32E5M2E5M2_RS_TN{}; + return SM90::GMMA::SPARSE::GMMA_64x256x64_F32E5M2E5M2_SS_TN{}; } #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 240 == 0) { - return SM90_64x240x32_F32E5M2E5M2_RS_TN{}; + return SM90::GMMA::SPARSE::GMMA_64x240x64_F32E5M2E5M2_SS_TN{}; } #endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 224 == 0) { - return SM90_64x224x32_F32E5M2E5M2_RS_TN{}; + return 
SM90::GMMA::SPARSE::GMMA_64x224x64_F32E5M2E5M2_SS_TN{}; } #endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 208 == 0) { - return SM90_64x208x32_F32E5M2E5M2_RS_TN{}; + return SM90::GMMA::SPARSE::GMMA_64x208x64_F32E5M2E5M2_SS_TN{}; } #endif else if constexpr (Tile_N % 192 == 0) { - return SM90_64x192x32_F32E5M2E5M2_RS_TN{}; + return SM90::GMMA::SPARSE::GMMA_64x192x64_F32E5M2E5M2_SS_TN{}; } #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 176 == 0) { - return SM90_64x176x32_F32E5M2E5M2_RS_TN{}; + return SM90::GMMA::SPARSE::GMMA_64x176x64_F32E5M2E5M2_SS_TN{}; } #endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 160 == 0) { - return SM90_64x160x32_F32E5M2E5M2_RS_TN{}; + return SM90::GMMA::SPARSE::GMMA_64x160x64_F32E5M2E5M2_SS_TN{}; } #endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 144 == 0) { - return SM90_64x144x32_F32E5M2E5M2_RS_TN{}; + return SM90::GMMA::SPARSE::GMMA_64x144x64_F32E5M2E5M2_SS_TN{}; } #endif else if constexpr (Tile_N % 128 == 0) { - return SM90_64x128x32_F32E5M2E5M2_RS_TN{}; + return SM90::GMMA::SPARSE::GMMA_64x128x64_F32E5M2E5M2_SS_TN{}; } #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 112 == 0) { - return SM90_64x112x32_F32E5M2E5M2_RS_TN{}; + return SM90::GMMA::SPARSE::GMMA_64x112x64_F32E5M2E5M2_SS_TN{}; } #endif else if constexpr (Tile_N % 96 == 0) { - return SM90_64x96x32_F32E5M2E5M2_RS_TN{}; + return SM90::GMMA::SPARSE::GMMA_64x96x64_F32E5M2E5M2_SS_TN{}; } #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 80 == 0) { - return SM90_64x80x32_F32E5M2E5M2_RS_TN{}; + return SM90::GMMA::SPARSE::GMMA_64x80x64_F32E5M2E5M2_SS_TN{}; } #endif else if constexpr (Tile_N % 64 == 0) { - return SM90_64x64x32_F32E5M2E5M2_RS_TN{}; + return SM90::GMMA::SPARSE::GMMA_64x64x64_F32E5M2E5M2_SS_TN{}; } #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 48 == 0) { - return SM90_64x48x32_F32E5M2E5M2_RS_TN{}; + return SM90::GMMA::SPARSE::GMMA_64x48x64_F32E5M2E5M2_SS_TN{}; } #endif else if constexpr (Tile_N % 32 == 0) { - return SM90_64x32x32_F32E5M2E5M2_RS_TN{}; + return SM90::GMMA::SPARSE::GMMA_64x32x64_F32E5M2E5M2_SS_TN{}; } else if constexpr (Tile_N % 16 == 0) { - return SM90_64x16x32_F32E5M2E5M2_RS_TN{}; + return SM90::GMMA::SPARSE::GMMA_64x16x64_F32E5M2E5M2_SS_TN{}; } else if constexpr (Tile_N % 8 == 0) { - return SM90_64x8x32_F32E5M2E5M2_RS_TN{}; + return SM90::GMMA::SPARSE::GMMA_64x8x64_F32E5M2E5M2_SS_TN{}; } else { static_assert(Tile_N % 8 == 0, "Tile_N must be a multiple of 8."); @@ -2663,76 +2670,76 @@ rs_op_selector() if constexpr (is_same_v && is_same_v) { static_assert(MajorA == GMMA::Major::K, "MajorA must be GMMA::Major::K for this config."); static_assert(MajorB == GMMA::Major::K, "MajorB must be GMMA::Major::K for this config."); - static_assert(size<2>(TileShape_MNK{}) % 32 == 0, "Tile_K must be a multiple of 32."); + static_assert(size<2>(TileShape_MNK{}) % 64 == 0, "Tile_K must be a multiple of 64."); if constexpr (Tile_N % 256 == 0) { - return SM90_64x256x32_S32S8S8_RS_TN{}; + return SM90::GMMA::SPARSE::GMMA_64x256x64_S32S8S8_SS_TN{}; } #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 240 == 0) { - return SM90_64x240x32_S32S8S8_RS_TN{}; + return SM90::GMMA::SPARSE::GMMA_64x240x64_S32S8S8_SS_TN{}; } #endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 224 == 0) { - return 
SM90_64x224x32_S32S8S8_RS_TN{}; + return SM90::GMMA::SPARSE::GMMA_64x224x64_S32S8S8_SS_TN{}; } #endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 208 == 0) { - return SM90_64x208x32_S32S8S8_RS_TN{}; + return SM90::GMMA::SPARSE::GMMA_64x208x64_S32S8S8_SS_TN{}; } #endif else if constexpr (Tile_N % 192 == 0) { - return SM90_64x192x32_S32S8S8_RS_TN{}; + return SM90::GMMA::SPARSE::GMMA_64x192x64_S32S8S8_SS_TN{}; } #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 176 == 0) { - return SM90_64x176x32_S32S8S8_RS_TN{}; + return SM90::GMMA::SPARSE::GMMA_64x176x64_S32S8S8_SS_TN{}; } #endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 160 == 0) { - return SM90_64x160x32_S32S8S8_RS_TN{}; + return SM90::GMMA::SPARSE::GMMA_64x160x64_S32S8S8_SS_TN{}; } #endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 144 == 0) { - return SM90_64x144x32_S32S8S8_RS_TN{}; + return SM90::GMMA::SPARSE::GMMA_64x144x64_S32S8S8_SS_TN{}; } #endif else if constexpr (Tile_N % 128 == 0) { - return SM90_64x128x32_S32S8S8_RS_TN{}; + return SM90::GMMA::SPARSE::GMMA_64x128x64_S32S8S8_SS_TN{}; } #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 112 == 0) { - return SM90_64x112x32_S32S8S8_RS_TN{}; + return SM90::GMMA::SPARSE::GMMA_64x112x64_S32S8S8_SS_TN{}; } #endif else if constexpr (Tile_N % 96 == 0) { - return SM90_64x96x32_S32S8S8_RS_TN{}; + return SM90::GMMA::SPARSE::GMMA_64x96x64_S32S8S8_SS_TN{}; } #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 80 == 0) { - return SM90_64x80x32_S32S8S8_RS_TN{}; + return SM90::GMMA::SPARSE::GMMA_64x80x64_S32S8S8_SS_TN{}; } #endif else if constexpr (Tile_N % 64 == 0) { - return SM90_64x64x32_S32S8S8_RS_TN{}; + return SM90::GMMA::SPARSE::GMMA_64x64x64_S32S8S8_SS_TN{}; } #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 48 == 0) { - return SM90_64x48x32_S32S8S8_RS_TN{}; + return SM90::GMMA::SPARSE::GMMA_64x48x64_S32S8S8_SS_TN{}; } #endif else if constexpr (Tile_N % 32 == 0) { - return SM90_64x32x32_S32S8S8_RS_TN{}; + return SM90::GMMA::SPARSE::GMMA_64x32x64_S32S8S8_SS_TN{}; } else if constexpr (Tile_N % 16 == 0) { - return SM90_64x16x32_S32S8S8_RS_TN{}; + return SM90::GMMA::SPARSE::GMMA_64x16x64_S32S8S8_SS_TN{}; } else if constexpr (Tile_N % 8 == 0) { - return SM90_64x8x32_S32S8S8_RS_TN{}; + return SM90::GMMA::SPARSE::GMMA_64x8x64_S32S8S8_SS_TN{}; } else { static_assert(Tile_N % 8 == 0, "Tile_N must be a multiple of 8."); @@ -2743,76 +2750,76 @@ rs_op_selector() else if constexpr (is_same_v && is_same_v) { static_assert(MajorA == GMMA::Major::K, "MajorA must be GMMA::Major::K for this config."); static_assert(MajorB == GMMA::Major::K, "MajorB must be GMMA::Major::K for this config."); - static_assert(size<2>(TileShape_MNK{}) % 32 == 0, "Tile_K must be a multiple of 32."); + static_assert(size<2>(TileShape_MNK{}) % 64 == 0, "Tile_K must be a multiple of 64."); if constexpr (Tile_N % 256 == 0) { - return SM90_64x256x32_S32S8U8_RS_TN{}; + return SM90::GMMA::SPARSE::GMMA_64x256x64_S32S8U8_SS_TN{}; } #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 240 == 0) { - return SM90_64x240x32_S32S8U8_RS_TN{}; + return SM90::GMMA::SPARSE::GMMA_64x240x64_S32S8U8_SS_TN{}; } #endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 224 == 0) { - return SM90_64x224x32_S32S8U8_RS_TN{}; + return SM90::GMMA::SPARSE::GMMA_64x224x64_S32S8U8_SS_TN{}; 
} #endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 208 == 0) { - return SM90_64x208x32_S32S8U8_RS_TN{}; + return SM90::GMMA::SPARSE::GMMA_64x208x64_S32S8U8_SS_TN{}; } #endif else if constexpr (Tile_N % 192 == 0) { - return SM90_64x192x32_S32S8U8_RS_TN{}; + return SM90::GMMA::SPARSE::GMMA_64x192x64_S32S8U8_SS_TN{}; } #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 176 == 0) { - return SM90_64x176x32_S32S8U8_RS_TN{}; + return SM90::GMMA::SPARSE::GMMA_64x176x64_S32S8U8_SS_TN{}; } #endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 160 == 0) { - return SM90_64x160x32_S32S8U8_RS_TN{}; + return SM90::GMMA::SPARSE::GMMA_64x160x64_S32S8U8_SS_TN{}; } #endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 144 == 0) { - return SM90_64x144x32_S32S8U8_RS_TN{}; + return SM90::GMMA::SPARSE::GMMA_64x144x64_S32S8U8_SS_TN{}; } #endif else if constexpr (Tile_N % 128 == 0) { - return SM90_64x128x32_S32S8U8_RS_TN{}; + return SM90::GMMA::SPARSE::GMMA_64x128x64_S32S8U8_SS_TN{}; } #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 112 == 0) { - return SM90_64x112x32_S32S8U8_RS_TN{}; + return SM90::GMMA::SPARSE::GMMA_64x112x64_S32S8U8_SS_TN{}; } #endif else if constexpr (Tile_N % 96 == 0) { - return SM90_64x96x32_S32S8U8_RS_TN{}; + return SM90::GMMA::SPARSE::GMMA_64x96x64_S32S8U8_SS_TN{}; } #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 80 == 0) { - return SM90_64x80x32_S32S8U8_RS_TN{}; + return SM90::GMMA::SPARSE::GMMA_64x80x64_S32S8U8_SS_TN{}; } #endif else if constexpr (Tile_N % 64 == 0) { - return SM90_64x64x32_S32S8U8_RS_TN{}; + return SM90::GMMA::SPARSE::GMMA_64x64x64_S32S8U8_SS_TN{}; } #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 48 == 0) { - return SM90_64x48x32_S32S8U8_RS_TN{}; + return SM90::GMMA::SPARSE::GMMA_64x48x64_S32S8U8_SS_TN{}; } #endif else if constexpr (Tile_N % 32 == 0) { - return SM90_64x32x32_S32S8U8_RS_TN{}; + return SM90::GMMA::SPARSE::GMMA_64x32x64_S32S8U8_SS_TN{}; } else if constexpr (Tile_N % 16 == 0) { - return SM90_64x16x32_S32S8U8_RS_TN{}; + return SM90::GMMA::SPARSE::GMMA_64x16x64_S32S8U8_SS_TN{}; } else if constexpr (Tile_N % 8 == 0) { - return SM90_64x8x32_S32S8U8_RS_TN{}; + return SM90::GMMA::SPARSE::GMMA_64x8x64_S32S8U8_SS_TN{}; } else { static_assert(Tile_N % 8 == 0, "Tile_N must be a multiple of 8."); @@ -2823,76 +2830,76 @@ rs_op_selector() else if constexpr (is_same_v && is_same_v) { static_assert(MajorA == GMMA::Major::K, "MajorA must be GMMA::Major::K for this config."); static_assert(MajorB == GMMA::Major::K, "MajorB must be GMMA::Major::K for this config."); - static_assert(size<2>(TileShape_MNK{}) % 32 == 0, "Tile_K must be a multiple of 32."); + static_assert(size<2>(TileShape_MNK{}) % 64 == 0, "Tile_K must be a multiple of 64."); if constexpr (Tile_N % 256 == 0) { - return SM90_64x256x32_S32U8S8_RS_TN{}; + return SM90::GMMA::SPARSE::GMMA_64x256x64_S32U8S8_SS_TN{}; } #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 240 == 0) { - return SM90_64x240x32_S32U8S8_RS_TN{}; + return SM90::GMMA::SPARSE::GMMA_64x240x64_S32U8S8_SS_TN{}; } #endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 224 == 0) { - return SM90_64x224x32_S32U8S8_RS_TN{}; + return SM90::GMMA::SPARSE::GMMA_64x224x64_S32U8S8_SS_TN{}; } #endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 208 
== 0) { - return SM90_64x208x32_S32U8S8_RS_TN{}; + return SM90::GMMA::SPARSE::GMMA_64x208x64_S32U8S8_SS_TN{}; } #endif else if constexpr (Tile_N % 192 == 0) { - return SM90_64x192x32_S32U8S8_RS_TN{}; + return SM90::GMMA::SPARSE::GMMA_64x192x64_S32U8S8_SS_TN{}; } #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 176 == 0) { - return SM90_64x176x32_S32U8S8_RS_TN{}; + return SM90::GMMA::SPARSE::GMMA_64x176x64_S32U8S8_SS_TN{}; } #endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 160 == 0) { - return SM90_64x160x32_S32U8S8_RS_TN{}; + return SM90::GMMA::SPARSE::GMMA_64x160x64_S32U8S8_SS_TN{}; } #endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 144 == 0) { - return SM90_64x144x32_S32U8S8_RS_TN{}; + return SM90::GMMA::SPARSE::GMMA_64x144x64_S32U8S8_SS_TN{}; } #endif else if constexpr (Tile_N % 128 == 0) { - return SM90_64x128x32_S32U8S8_RS_TN{}; + return SM90::GMMA::SPARSE::GMMA_64x128x64_S32U8S8_SS_TN{}; } #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 112 == 0) { - return SM90_64x112x32_S32U8S8_RS_TN{}; + return SM90::GMMA::SPARSE::GMMA_64x112x64_S32U8S8_SS_TN{}; } #endif else if constexpr (Tile_N % 96 == 0) { - return SM90_64x96x32_S32U8S8_RS_TN{}; + return SM90::GMMA::SPARSE::GMMA_64x96x64_S32U8S8_SS_TN{}; } #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 80 == 0) { - return SM90_64x80x32_S32U8S8_RS_TN{}; + return SM90::GMMA::SPARSE::GMMA_64x80x64_S32U8S8_SS_TN{}; } #endif else if constexpr (Tile_N % 64 == 0) { - return SM90_64x64x32_S32U8S8_RS_TN{}; + return SM90::GMMA::SPARSE::GMMA_64x64x64_S32U8S8_SS_TN{}; } #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 48 == 0) { - return SM90_64x48x32_S32U8S8_RS_TN{}; + return SM90::GMMA::SPARSE::GMMA_64x48x64_S32U8S8_SS_TN{}; } #endif else if constexpr (Tile_N % 32 == 0) { - return SM90_64x32x32_S32U8S8_RS_TN{}; + return SM90::GMMA::SPARSE::GMMA_64x32x64_S32U8S8_SS_TN{}; } else if constexpr (Tile_N % 16 == 0) { - return SM90_64x16x32_S32U8S8_RS_TN{}; + return SM90::GMMA::SPARSE::GMMA_64x16x64_S32U8S8_SS_TN{}; } else if constexpr (Tile_N % 8 == 0) { - return SM90_64x8x32_S32U8S8_RS_TN{}; + return SM90::GMMA::SPARSE::GMMA_64x8x64_S32U8S8_SS_TN{}; } else { static_assert(Tile_N % 8 == 0, "Tile_N must be a multiple of 8."); @@ -2901,78 +2908,2726 @@ rs_op_selector() // Input A: uint8_t ; Input B: uint8_t else if constexpr (is_same_v && is_same_v) { + static_assert(MajorA == GMMA::Major::K, "MajorA must be GMMA::Major::K for this config."); + static_assert(MajorB == GMMA::Major::K, "MajorB must be GMMA::Major::K for this config."); + static_assert(size<2>(TileShape_MNK{}) % 64 == 0, "Tile_K must be a multiple of 64."); + + if constexpr (Tile_N % 256 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x256x64_S32U8U8_SS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 240 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x240x64_S32U8U8_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 224 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x224x64_S32U8U8_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 208 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x208x64_S32U8U8_SS_TN{}; + } +#endif + else if constexpr (Tile_N % 192 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x192x64_S32U8U8_SS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) 
+ else if constexpr (Tile_N % 176 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x176x64_S32U8U8_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 160 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x160x64_S32U8U8_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 144 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x144x64_S32U8U8_SS_TN{}; + } +#endif + else if constexpr (Tile_N % 128 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x128x64_S32U8U8_SS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 112 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x112x64_S32U8U8_SS_TN{}; + } +#endif + else if constexpr (Tile_N % 96 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x96x64_S32U8U8_SS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 80 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x80x64_S32U8U8_SS_TN{}; + } +#endif + else if constexpr (Tile_N % 64 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x64x64_S32U8U8_SS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 48 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x48x64_S32U8U8_SS_TN{}; + } +#endif + else if constexpr (Tile_N % 32 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x32x64_S32U8U8_SS_TN{}; + } + else if constexpr (Tile_N % 16 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x16x64_S32U8U8_SS_TN{}; + } + else if constexpr (Tile_N % 8 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x8x64_S32U8U8_SS_TN{}; + } + else { + static_assert(Tile_N % 8 == 0, "Tile_N must be a multiple of 8."); + } + } + + else { + static_assert(sizeof(ElementA) == 0, "No eligible GMMA operator for request configuration."); + } + } + + // Unknown accumulator type + else { + static_assert(sizeof(ElementC) == 0, "Unknown ElementC accumulator type."); + } +} + +template < + class ElementA, + class ElementB, + class ElementC, + class TileShape_MNK, + GMMA::Major MajorA = GMMA::Major::K, + GMMA::Major MajorB = GMMA::Major::K, + auto... Args // e.g. 
GMMA::ScaleOut::One, [GMMA::ScaleIn::One, GMMA::ScaleIn::One] + // But most commonly leave empty for defaults +> +CUTE_HOST_DEVICE constexpr +auto +rs_op_selector() +{ + static_assert(is_static::value, "TileShape_MNK must be static."); + static_assert(rank(TileShape_MNK{}) == 3, "TileShape_MNK must be rank 3."); + static_assert(size<0>(TileShape_MNK{}) % 64 == 0, "Tile_M must be a multiple of 64."); + static_assert(MajorA == GMMA::Major::K, "Register source A operand GMMAs must have K-major A layout."); + auto Tile_N = size<1>(TileShape_MNK{}); + + // F16 accumulator + if constexpr (is_same_v) { + + // Input A: half_t ; Input B: half_t + if constexpr (is_same_v && is_same_v) { + static_assert(size<2>(TileShape_MNK{}) % 16 == 0, "Tile_K must be a multiple of 16."); + + if constexpr (Tile_N % 256 == 0) { + return SM90::GMMA::MMA_64x256x16_F16F16F16_RS{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 240 == 0) { + return SM90::GMMA::MMA_64x240x16_F16F16F16_RS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 224 == 0) { + return SM90::GMMA::MMA_64x224x16_F16F16F16_RS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 208 == 0) { + return SM90::GMMA::MMA_64x208x16_F16F16F16_RS{}; + } +#endif + else if constexpr (Tile_N % 192 == 0) { + return SM90::GMMA::MMA_64x192x16_F16F16F16_RS{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 176 == 0) { + return SM90::GMMA::MMA_64x176x16_F16F16F16_RS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 160 == 0) { + return SM90::GMMA::MMA_64x160x16_F16F16F16_RS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 144 == 0) { + return SM90::GMMA::MMA_64x144x16_F16F16F16_RS{}; + } +#endif + else if constexpr (Tile_N % 128 == 0) { + return SM90::GMMA::MMA_64x128x16_F16F16F16_RS{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 112 == 0) { + return SM90::GMMA::MMA_64x112x16_F16F16F16_RS{}; + } +#endif + else if constexpr (Tile_N % 96 == 0) { + return SM90::GMMA::MMA_64x96x16_F16F16F16_RS{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 80 == 0) { + return SM90::GMMA::MMA_64x80x16_F16F16F16_RS{}; + } +#endif + else if constexpr (Tile_N % 64 == 0) { + return SM90::GMMA::MMA_64x64x16_F16F16F16_RS{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 48 == 0) { + return SM90::GMMA::MMA_64x48x16_F16F16F16_RS{}; + } +#endif + else if constexpr (Tile_N % 32 == 0) { + return SM90::GMMA::MMA_64x32x16_F16F16F16_RS{}; + } + else if constexpr (Tile_N % 16 == 0) { + return SM90::GMMA::MMA_64x16x16_F16F16F16_RS{}; + } + else if constexpr (Tile_N % 8 == 0) { + return SM90::GMMA::MMA_64x8x16_F16F16F16_RS{}; + } + else { + static_assert(Tile_N % 8 == 0, "Tile_N must be a multiple of 8."); + } + } + + // Input A: float_e4m3_t ; Input B: float_e4m3_t + else if constexpr (is_same_v && is_same_v) { + static_assert(MajorA == GMMA::Major::K, "MajorA must be GMMA::Major::K for this config."); + static_assert(MajorB == GMMA::Major::K, "MajorB must be GMMA::Major::K for this config."); + static_assert(size<2>(TileShape_MNK{}) % 32 == 0, "Tile_K must be a multiple of 32."); + + if constexpr (Tile_N % 256 == 0) { + return SM90::GMMA::MMA_64x256x32_F16E4M3E4M3_RS_TN{}; + } +#if 
defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 240 == 0) { + return SM90::GMMA::MMA_64x240x32_F16E4M3E4M3_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 224 == 0) { + return SM90::GMMA::MMA_64x224x32_F16E4M3E4M3_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 208 == 0) { + return SM90::GMMA::MMA_64x208x32_F16E4M3E4M3_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 192 == 0) { + return SM90::GMMA::MMA_64x192x32_F16E4M3E4M3_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 176 == 0) { + return SM90::GMMA::MMA_64x176x32_F16E4M3E4M3_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 160 == 0) { + return SM90::GMMA::MMA_64x160x32_F16E4M3E4M3_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 144 == 0) { + return SM90::GMMA::MMA_64x144x32_F16E4M3E4M3_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 128 == 0) { + return SM90::GMMA::MMA_64x128x32_F16E4M3E4M3_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 112 == 0) { + return SM90::GMMA::MMA_64x112x32_F16E4M3E4M3_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 96 == 0) { + return SM90::GMMA::MMA_64x96x32_F16E4M3E4M3_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 80 == 0) { + return SM90::GMMA::MMA_64x80x32_F16E4M3E4M3_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 64 == 0) { + return SM90::GMMA::MMA_64x64x32_F16E4M3E4M3_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 48 == 0) { + return SM90::GMMA::MMA_64x48x32_F16E4M3E4M3_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 32 == 0) { + return SM90::GMMA::MMA_64x32x32_F16E4M3E4M3_RS_TN{}; + } + else if constexpr (Tile_N % 16 == 0) { + return SM90::GMMA::MMA_64x16x32_F16E4M3E4M3_RS_TN{}; + } + else if constexpr (Tile_N % 8 == 0) { + return SM90::GMMA::MMA_64x8x32_F16E4M3E4M3_RS_TN{}; + } + else { + static_assert(Tile_N % 8 == 0, "Tile_N must be a multiple of 8."); + } + } + + // Input A: float_e4m3_t ; Input B: float_e5m2_t + else if constexpr (is_same_v && is_same_v) { static_assert(MajorA == GMMA::Major::K, "MajorA must be GMMA::Major::K for this config."); static_assert(MajorB == GMMA::Major::K, "MajorB must be GMMA::Major::K for this config."); static_assert(size<2>(TileShape_MNK{}) % 32 == 0, "Tile_K must be a multiple of 32."); if constexpr (Tile_N % 256 == 0) { - return SM90_64x256x32_S32U8U8_RS_TN{}; + return SM90::GMMA::MMA_64x256x32_F16E4M3E5M2_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 240 == 0) { + return SM90::GMMA::MMA_64x240x32_F16E4M3E5M2_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 224 == 0) { + return SM90::GMMA::MMA_64x224x32_F16E4M3E5M2_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 208 == 0) { + return SM90::GMMA::MMA_64x208x32_F16E4M3E5M2_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 192 == 0) { + return SM90::GMMA::MMA_64x192x32_F16E4M3E5M2_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 176 == 0) { + return SM90::GMMA::MMA_64x176x32_F16E4M3E5M2_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if 
constexpr (Tile_N % 160 == 0) { + return SM90::GMMA::MMA_64x160x32_F16E4M3E5M2_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 144 == 0) { + return SM90::GMMA::MMA_64x144x32_F16E4M3E5M2_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 128 == 0) { + return SM90::GMMA::MMA_64x128x32_F16E4M3E5M2_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 112 == 0) { + return SM90::GMMA::MMA_64x112x32_F16E4M3E5M2_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 96 == 0) { + return SM90::GMMA::MMA_64x96x32_F16E4M3E5M2_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 80 == 0) { + return SM90::GMMA::MMA_64x80x32_F16E4M3E5M2_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 64 == 0) { + return SM90::GMMA::MMA_64x64x32_F16E4M3E5M2_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 48 == 0) { + return SM90::GMMA::MMA_64x48x32_F16E4M3E5M2_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 32 == 0) { + return SM90::GMMA::MMA_64x32x32_F16E4M3E5M2_RS_TN{}; + } + else if constexpr (Tile_N % 16 == 0) { + return SM90::GMMA::MMA_64x16x32_F16E4M3E5M2_RS_TN{}; + } + else if constexpr (Tile_N % 8 == 0) { + return SM90::GMMA::MMA_64x8x32_F16E4M3E5M2_RS_TN{}; + } + else { + static_assert(Tile_N % 8 == 0, "Tile_N must be a multiple of 8."); + } + } + + // Input A: float_e5m2_t ; Input B: float_e4m3_t + else if constexpr (is_same_v && is_same_v) { + static_assert(MajorA == GMMA::Major::K, "MajorA must be GMMA::Major::K for this config."); + static_assert(MajorB == GMMA::Major::K, "MajorB must be GMMA::Major::K for this config."); + static_assert(size<2>(TileShape_MNK{}) % 32 == 0, "Tile_K must be a multiple of 32."); + + if constexpr (Tile_N % 256 == 0) { + return SM90::GMMA::MMA_64x256x32_F16E5M2E4M3_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 240 == 0) { + return SM90::GMMA::MMA_64x240x32_F16E5M2E4M3_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 224 == 0) { + return SM90::GMMA::MMA_64x224x32_F16E5M2E4M3_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 208 == 0) { + return SM90::GMMA::MMA_64x208x32_F16E5M2E4M3_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 192 == 0) { + return SM90::GMMA::MMA_64x192x32_F16E5M2E4M3_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 176 == 0) { + return SM90::GMMA::MMA_64x176x32_F16E5M2E4M3_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 160 == 0) { + return SM90::GMMA::MMA_64x160x32_F16E5M2E4M3_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 144 == 0) { + return SM90::GMMA::MMA_64x144x32_F16E5M2E4M3_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 128 == 0) { + return SM90::GMMA::MMA_64x128x32_F16E5M2E4M3_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 112 == 0) { + return SM90::GMMA::MMA_64x112x32_F16E5M2E4M3_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 96 == 0) { + return SM90::GMMA::MMA_64x96x32_F16E5M2E4M3_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 80 == 0) { + return SM90::GMMA::MMA_64x80x32_F16E5M2E4M3_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 64 == 0) { + return 
SM90::GMMA::MMA_64x64x32_F16E5M2E4M3_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 48 == 0) { + return SM90::GMMA::MMA_64x48x32_F16E5M2E4M3_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 32 == 0) { + return SM90::GMMA::MMA_64x32x32_F16E5M2E4M3_RS_TN{}; + } + else if constexpr (Tile_N % 16 == 0) { + return SM90::GMMA::MMA_64x16x32_F16E5M2E4M3_RS_TN{}; + } + else if constexpr (Tile_N % 8 == 0) { + return SM90::GMMA::MMA_64x8x32_F16E5M2E4M3_RS_TN{}; + } + else { + static_assert(Tile_N % 8 == 0, "Tile_N must be a multiple of 8."); + } + } + + // Input A: float_e5m2_t ; Input B: float_e5m2_t + else if constexpr (is_same_v && is_same_v) { + static_assert(MajorA == GMMA::Major::K, "MajorA must be GMMA::Major::K for this config."); + static_assert(MajorB == GMMA::Major::K, "MajorB must be GMMA::Major::K for this config."); + static_assert(size<2>(TileShape_MNK{}) % 32 == 0, "Tile_K must be a multiple of 32."); + + if constexpr (Tile_N % 256 == 0) { + return SM90::GMMA::MMA_64x256x32_F16E5M2E5M2_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 240 == 0) { + return SM90::GMMA::MMA_64x240x32_F16E5M2E5M2_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 224 == 0) { + return SM90::GMMA::MMA_64x224x32_F16E5M2E5M2_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 208 == 0) { + return SM90::GMMA::MMA_64x208x32_F16E5M2E5M2_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 192 == 0) { + return SM90::GMMA::MMA_64x192x32_F16E5M2E5M2_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 176 == 0) { + return SM90::GMMA::MMA_64x176x32_F16E5M2E5M2_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 160 == 0) { + return SM90::GMMA::MMA_64x160x32_F16E5M2E5M2_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 144 == 0) { + return SM90::GMMA::MMA_64x144x32_F16E5M2E5M2_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 128 == 0) { + return SM90::GMMA::MMA_64x128x32_F16E5M2E5M2_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 112 == 0) { + return SM90::GMMA::MMA_64x112x32_F16E5M2E5M2_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 96 == 0) { + return SM90::GMMA::MMA_64x96x32_F16E5M2E5M2_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 80 == 0) { + return SM90::GMMA::MMA_64x80x32_F16E5M2E5M2_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 64 == 0) { + return SM90::GMMA::MMA_64x64x32_F16E5M2E5M2_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 48 == 0) { + return SM90::GMMA::MMA_64x48x32_F16E5M2E5M2_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 32 == 0) { + return SM90::GMMA::MMA_64x32x32_F16E5M2E5M2_RS_TN{}; + } + else if constexpr (Tile_N % 16 == 0) { + return SM90::GMMA::MMA_64x16x32_F16E5M2E5M2_RS_TN{}; + } + else if constexpr (Tile_N % 8 == 0) { + return SM90::GMMA::MMA_64x8x32_F16E5M2E5M2_RS_TN{}; + } + else { + static_assert(Tile_N % 8 == 0, "Tile_N must be a multiple of 8."); + } + } + + else { + static_assert(sizeof(ElementA) == 0, "No eligible GMMA operator for request configuration."); + } + } + + // F32 accumulator + else if constexpr (is_same_v) { + + // Input A: half_t ; Input B: half_t + if constexpr (is_same_v && 
is_same_v) { + static_assert(size<2>(TileShape_MNK{}) % 16 == 0, "Tile_K must be a multiple of 16."); + + if constexpr (Tile_N % 256 == 0) { + return SM90::GMMA::MMA_64x256x16_F32F16F16_RS{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 240 == 0) { + return SM90::GMMA::MMA_64x240x16_F32F16F16_RS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 224 == 0) { + return SM90::GMMA::MMA_64x224x16_F32F16F16_RS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 208 == 0) { + return SM90::GMMA::MMA_64x208x16_F32F16F16_RS{}; + } +#endif + else if constexpr (Tile_N % 192 == 0) { + return SM90::GMMA::MMA_64x192x16_F32F16F16_RS{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 176 == 0) { + return SM90::GMMA::MMA_64x176x16_F32F16F16_RS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 160 == 0) { + return SM90::GMMA::MMA_64x160x16_F32F16F16_RS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 144 == 0) { + return SM90::GMMA::MMA_64x144x16_F32F16F16_RS{}; + } +#endif + else if constexpr (Tile_N % 128 == 0) { + return SM90::GMMA::MMA_64x128x16_F32F16F16_RS{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 112 == 0) { + return SM90::GMMA::MMA_64x112x16_F32F16F16_RS{}; + } +#endif + else if constexpr (Tile_N % 96 == 0) { + return SM90::GMMA::MMA_64x96x16_F32F16F16_RS{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 80 == 0) { + return SM90::GMMA::MMA_64x80x16_F32F16F16_RS{}; + } +#endif + else if constexpr (Tile_N % 64 == 0) { + return SM90::GMMA::MMA_64x64x16_F32F16F16_RS{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 48 == 0) { + return SM90::GMMA::MMA_64x48x16_F32F16F16_RS{}; + } +#endif + else if constexpr (Tile_N % 32 == 0) { + return SM90::GMMA::MMA_64x32x16_F32F16F16_RS{}; + } + else if constexpr (Tile_N % 16 == 0) { + return SM90::GMMA::MMA_64x16x16_F32F16F16_RS{}; + } + else if constexpr (Tile_N % 8 == 0) { + return SM90::GMMA::MMA_64x8x16_F32F16F16_RS{}; + } + else { + static_assert(Tile_N % 8 == 0, "Tile_N must be a multiple of 8."); + } + } + + // Input A: bfloat16_t ; Input B: bfloat16_t + else if constexpr (is_same_v && is_same_v) { + static_assert(size<2>(TileShape_MNK{}) % 16 == 0, "Tile_K must be a multiple of 16."); + + if constexpr (Tile_N % 256 == 0) { + return SM90::GMMA::MMA_64x256x16_F32BF16BF16_RS{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 240 == 0) { + return SM90::GMMA::MMA_64x240x16_F32BF16BF16_RS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 224 == 0) { + return SM90::GMMA::MMA_64x224x16_F32BF16BF16_RS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 208 == 0) { + return SM90::GMMA::MMA_64x208x16_F32BF16BF16_RS{}; + } +#endif + else if constexpr (Tile_N % 192 == 0) { + return SM90::GMMA::MMA_64x192x16_F32BF16BF16_RS{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 176 == 0) { + return SM90::GMMA::MMA_64x176x16_F32BF16BF16_RS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 160 == 0) { + return SM90::GMMA::MMA_64x160x16_F32BF16BF16_RS{}; + } +#endif +#if 
defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 144 == 0) { + return SM90::GMMA::MMA_64x144x16_F32BF16BF16_RS{}; + } +#endif + else if constexpr (Tile_N % 128 == 0) { + return SM90::GMMA::MMA_64x128x16_F32BF16BF16_RS{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 112 == 0) { + return SM90::GMMA::MMA_64x112x16_F32BF16BF16_RS{}; + } +#endif + else if constexpr (Tile_N % 96 == 0) { + return SM90::GMMA::MMA_64x96x16_F32BF16BF16_RS{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 80 == 0) { + return SM90::GMMA::MMA_64x80x16_F32BF16BF16_RS{}; + } +#endif + else if constexpr (Tile_N % 64 == 0) { + return SM90::GMMA::MMA_64x64x16_F32BF16BF16_RS{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 48 == 0) { + return SM90::GMMA::MMA_64x48x16_F32BF16BF16_RS{}; + } +#endif + else if constexpr (Tile_N % 32 == 0) { + return SM90::GMMA::MMA_64x32x16_F32BF16BF16_RS{}; + } + else if constexpr (Tile_N % 16 == 0) { + return SM90::GMMA::MMA_64x16x16_F32BF16BF16_RS{}; + } + else if constexpr (Tile_N % 8 == 0) { + return SM90::GMMA::MMA_64x8x16_F32BF16BF16_RS{}; + } + else { + static_assert(Tile_N % 8 == 0, "Tile_N must be a multiple of 8."); + } + } + + // Input A: tfloat32_t ; Input B: tfloat32_t + else if constexpr (is_same_v && is_same_v) { + static_assert(MajorA == GMMA::Major::K, "MajorA must be GMMA::Major::K for this config."); + static_assert(MajorB == GMMA::Major::K, "MajorB must be GMMA::Major::K for this config."); + static_assert(size<2>(TileShape_MNK{}) % 8 == 0, "Tile_K must be a multiple of 8."); + + if constexpr (Tile_N % 256 == 0) { + return SM90::GMMA::MMA_64x256x8_F32TF32TF32_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 240 == 0) { + return SM90::GMMA::MMA_64x240x8_F32TF32TF32_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 224 == 0) { + return SM90::GMMA::MMA_64x224x8_F32TF32TF32_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 208 == 0) { + return SM90::GMMA::MMA_64x208x8_F32TF32TF32_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 192 == 0) { + return SM90::GMMA::MMA_64x192x8_F32TF32TF32_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 176 == 0) { + return SM90::GMMA::MMA_64x176x8_F32TF32TF32_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 160 == 0) { + return SM90::GMMA::MMA_64x160x8_F32TF32TF32_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 144 == 0) { + return SM90::GMMA::MMA_64x144x8_F32TF32TF32_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 128 == 0) { + return SM90::GMMA::MMA_64x128x8_F32TF32TF32_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 112 == 0) { + return SM90::GMMA::MMA_64x112x8_F32TF32TF32_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 96 == 0) { + return SM90::GMMA::MMA_64x96x8_F32TF32TF32_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 80 == 0) { + return SM90::GMMA::MMA_64x80x8_F32TF32TF32_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 64 == 0) { + return SM90::GMMA::MMA_64x64x8_F32TF32TF32_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 48 == 0) { + return 
SM90::GMMA::MMA_64x48x8_F32TF32TF32_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 32 == 0) { + return SM90::GMMA::MMA_64x32x8_F32TF32TF32_RS_TN{}; + } + else if constexpr (Tile_N % 16 == 0) { + return SM90::GMMA::MMA_64x16x8_F32TF32TF32_RS_TN{}; + } + else if constexpr (Tile_N % 8 == 0) { + return SM90::GMMA::MMA_64x8x8_F32TF32TF32_RS_TN{}; + } + else { + static_assert(Tile_N % 8 == 0, "Tile_N must be a multiple of 8."); + } + } + + // Input A: float_e4m3_t ; Input B: float_e4m3_t + else if constexpr (is_same_v && is_same_v) { + static_assert(MajorA == GMMA::Major::K, "MajorA must be GMMA::Major::K for this config."); + static_assert(MajorB == GMMA::Major::K, "MajorB must be GMMA::Major::K for this config."); + static_assert(size<2>(TileShape_MNK{}) % 32 == 0, "Tile_K must be a multiple of 32."); + + if constexpr (Tile_N % 256 == 0) { + return SM90::GMMA::MMA_64x256x32_F32E4M3E4M3_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 240 == 0) { + return SM90::GMMA::MMA_64x240x32_F32E4M3E4M3_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 224 == 0) { + return SM90::GMMA::MMA_64x224x32_F32E4M3E4M3_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 208 == 0) { + return SM90::GMMA::MMA_64x208x32_F32E4M3E4M3_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 192 == 0) { + return SM90::GMMA::MMA_64x192x32_F32E4M3E4M3_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 176 == 0) { + return SM90::GMMA::MMA_64x176x32_F32E4M3E4M3_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 160 == 0) { + return SM90::GMMA::MMA_64x160x32_F32E4M3E4M3_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 144 == 0) { + return SM90::GMMA::MMA_64x144x32_F32E4M3E4M3_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 128 == 0) { + return SM90::GMMA::MMA_64x128x32_F32E4M3E4M3_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 112 == 0) { + return SM90::GMMA::MMA_64x112x32_F32E4M3E4M3_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 96 == 0) { + return SM90::GMMA::MMA_64x96x32_F32E4M3E4M3_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 80 == 0) { + return SM90::GMMA::MMA_64x80x32_F32E4M3E4M3_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 64 == 0) { + return SM90::GMMA::MMA_64x64x32_F32E4M3E4M3_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 48 == 0) { + return SM90::GMMA::MMA_64x48x32_F32E4M3E4M3_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 32 == 0) { + return SM90::GMMA::MMA_64x32x32_F32E4M3E4M3_RS_TN{}; + } + else if constexpr (Tile_N % 16 == 0) { + return SM90::GMMA::MMA_64x16x32_F32E4M3E4M3_RS_TN{}; + } + else if constexpr (Tile_N % 8 == 0) { + return SM90::GMMA::MMA_64x8x32_F32E4M3E4M3_RS_TN{}; + } + else { + static_assert(Tile_N % 8 == 0, "Tile_N must be a multiple of 8."); + } + } + + // Input A: float_e4m3_t ; Input B: float_e5m2_t + else if constexpr (is_same_v && is_same_v) { + static_assert(MajorA == GMMA::Major::K, "MajorA must be GMMA::Major::K for this config."); + static_assert(MajorB == GMMA::Major::K, "MajorB must be GMMA::Major::K for this config."); + static_assert(size<2>(TileShape_MNK{}) % 32 == 0, "Tile_K must be a multiple of 32."); + + if constexpr (Tile_N 
% 256 == 0) { + return SM90::GMMA::MMA_64x256x32_F32E4M3E5M2_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 240 == 0) { + return SM90::GMMA::MMA_64x240x32_F32E4M3E5M2_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 224 == 0) { + return SM90::GMMA::MMA_64x224x32_F32E4M3E5M2_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 208 == 0) { + return SM90::GMMA::MMA_64x208x32_F32E4M3E5M2_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 192 == 0) { + return SM90::GMMA::MMA_64x192x32_F32E4M3E5M2_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 176 == 0) { + return SM90::GMMA::MMA_64x176x32_F32E4M3E5M2_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 160 == 0) { + return SM90::GMMA::MMA_64x160x32_F32E4M3E5M2_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 144 == 0) { + return SM90::GMMA::MMA_64x144x32_F32E4M3E5M2_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 128 == 0) { + return SM90::GMMA::MMA_64x128x32_F32E4M3E5M2_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 112 == 0) { + return SM90::GMMA::MMA_64x112x32_F32E4M3E5M2_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 96 == 0) { + return SM90::GMMA::MMA_64x96x32_F32E4M3E5M2_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 80 == 0) { + return SM90::GMMA::MMA_64x80x32_F32E4M3E5M2_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 64 == 0) { + return SM90::GMMA::MMA_64x64x32_F32E4M3E5M2_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 48 == 0) { + return SM90::GMMA::MMA_64x48x32_F32E4M3E5M2_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 32 == 0) { + return SM90::GMMA::MMA_64x32x32_F32E4M3E5M2_RS_TN{}; + } + else if constexpr (Tile_N % 16 == 0) { + return SM90::GMMA::MMA_64x16x32_F32E4M3E5M2_RS_TN{}; + } + else if constexpr (Tile_N % 8 == 0) { + return SM90::GMMA::MMA_64x8x32_F32E4M3E5M2_RS_TN{}; + } + else { + static_assert(Tile_N % 8 == 0, "Tile_N must be a multiple of 8."); + } + } + + // Input A: float_e5m2_t ; Input B: float_e4m3_t + else if constexpr (is_same_v && is_same_v) { + static_assert(MajorA == GMMA::Major::K, "MajorA must be GMMA::Major::K for this config."); + static_assert(MajorB == GMMA::Major::K, "MajorB must be GMMA::Major::K for this config."); + static_assert(size<2>(TileShape_MNK{}) % 32 == 0, "Tile_K must be a multiple of 32."); + + if constexpr (Tile_N % 256 == 0) { + return SM90::GMMA::MMA_64x256x32_F32E5M2E4M3_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 240 == 0) { + return SM90::GMMA::MMA_64x240x32_F32E5M2E4M3_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 224 == 0) { + return SM90::GMMA::MMA_64x224x32_F32E5M2E4M3_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 208 == 0) { + return SM90::GMMA::MMA_64x208x32_F32E5M2E4M3_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 192 == 0) { + return SM90::GMMA::MMA_64x192x32_F32E5M2E4M3_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 176 == 0) { + return SM90::GMMA::MMA_64x176x32_F32E5M2E4M3_RS_TN{}; + } +#endif +#if 
defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 160 == 0) { + return SM90::GMMA::MMA_64x160x32_F32E5M2E4M3_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 144 == 0) { + return SM90::GMMA::MMA_64x144x32_F32E5M2E4M3_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 128 == 0) { + return SM90::GMMA::MMA_64x128x32_F32E5M2E4M3_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 112 == 0) { + return SM90::GMMA::MMA_64x112x32_F32E5M2E4M3_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 96 == 0) { + return SM90::GMMA::MMA_64x96x32_F32E5M2E4M3_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 80 == 0) { + return SM90::GMMA::MMA_64x80x32_F32E5M2E4M3_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 64 == 0) { + return SM90::GMMA::MMA_64x64x32_F32E5M2E4M3_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 48 == 0) { + return SM90::GMMA::MMA_64x48x32_F32E5M2E4M3_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 32 == 0) { + return SM90::GMMA::MMA_64x32x32_F32E5M2E4M3_RS_TN{}; + } + else if constexpr (Tile_N % 16 == 0) { + return SM90::GMMA::MMA_64x16x32_F32E5M2E4M3_RS_TN{}; + } + else if constexpr (Tile_N % 8 == 0) { + return SM90::GMMA::MMA_64x8x32_F32E5M2E4M3_RS_TN{}; + } + else { + static_assert(Tile_N % 8 == 0, "Tile_N must be a multiple of 8."); + } + } + + // Input A: float_e5m2_t ; Input B: float_e5m2_t + else if constexpr (is_same_v && is_same_v) { + static_assert(MajorA == GMMA::Major::K, "MajorA must be GMMA::Major::K for this config."); + static_assert(MajorB == GMMA::Major::K, "MajorB must be GMMA::Major::K for this config."); + static_assert(size<2>(TileShape_MNK{}) % 32 == 0, "Tile_K must be a multiple of 32."); + + if constexpr (Tile_N % 256 == 0) { + return SM90::GMMA::MMA_64x256x32_F32E5M2E5M2_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 240 == 0) { + return SM90::GMMA::MMA_64x240x32_F32E5M2E5M2_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 224 == 0) { + return SM90::GMMA::MMA_64x224x32_F32E5M2E5M2_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 208 == 0) { + return SM90::GMMA::MMA_64x208x32_F32E5M2E5M2_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 192 == 0) { + return SM90::GMMA::MMA_64x192x32_F32E5M2E5M2_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 176 == 0) { + return SM90::GMMA::MMA_64x176x32_F32E5M2E5M2_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 160 == 0) { + return SM90::GMMA::MMA_64x160x32_F32E5M2E5M2_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 144 == 0) { + return SM90::GMMA::MMA_64x144x32_F32E5M2E5M2_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 128 == 0) { + return SM90::GMMA::MMA_64x128x32_F32E5M2E5M2_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 112 == 0) { + return SM90::GMMA::MMA_64x112x32_F32E5M2E5M2_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 96 == 0) { + return SM90::GMMA::MMA_64x96x32_F32E5M2E5M2_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 80 == 0) { + return SM90::GMMA::MMA_64x80x32_F32E5M2E5M2_RS_TN{}; + } 
+#endif + else if constexpr (Tile_N % 64 == 0) { + return SM90::GMMA::MMA_64x64x32_F32E5M2E5M2_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 48 == 0) { + return SM90::GMMA::MMA_64x48x32_F32E5M2E5M2_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 32 == 0) { + return SM90::GMMA::MMA_64x32x32_F32E5M2E5M2_RS_TN{}; + } + else if constexpr (Tile_N % 16 == 0) { + return SM90::GMMA::MMA_64x16x32_F32E5M2E5M2_RS_TN{}; + } + else if constexpr (Tile_N % 8 == 0) { + return SM90::GMMA::MMA_64x8x32_F32E5M2E5M2_RS_TN{}; + } + else { + static_assert(Tile_N % 8 == 0, "Tile_N must be a multiple of 8."); + } + } + + else { + static_assert(sizeof(ElementA) == 0, "No eligible GMMA operator for request configuration."); + } + } + + // S32 accumulator + else if constexpr (is_same_v) { + + // Input A: int8_t ; Input B: int8_t + if constexpr (is_same_v && is_same_v) { + static_assert(MajorA == GMMA::Major::K, "MajorA must be GMMA::Major::K for this config."); + static_assert(MajorB == GMMA::Major::K, "MajorB must be GMMA::Major::K for this config."); + static_assert(size<2>(TileShape_MNK{}) % 32 == 0, "Tile_K must be a multiple of 32."); + + if constexpr (Tile_N % 256 == 0) { + return SM90::GMMA::MMA_64x256x32_S32S8S8_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 240 == 0) { + return SM90::GMMA::MMA_64x240x32_S32S8S8_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 224 == 0) { + return SM90::GMMA::MMA_64x224x32_S32S8S8_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 208 == 0) { + return SM90::GMMA::MMA_64x208x32_S32S8S8_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 192 == 0) { + return SM90::GMMA::MMA_64x192x32_S32S8S8_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 176 == 0) { + return SM90::GMMA::MMA_64x176x32_S32S8S8_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 160 == 0) { + return SM90::GMMA::MMA_64x160x32_S32S8S8_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 144 == 0) { + return SM90::GMMA::MMA_64x144x32_S32S8S8_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 128 == 0) { + return SM90::GMMA::MMA_64x128x32_S32S8S8_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 112 == 0) { + return SM90::GMMA::MMA_64x112x32_S32S8S8_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 96 == 0) { + return SM90::GMMA::MMA_64x96x32_S32S8S8_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 80 == 0) { + return SM90::GMMA::MMA_64x80x32_S32S8S8_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 64 == 0) { + return SM90::GMMA::MMA_64x64x32_S32S8S8_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 48 == 0) { + return SM90::GMMA::MMA_64x48x32_S32S8S8_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 32 == 0) { + return SM90::GMMA::MMA_64x32x32_S32S8S8_RS_TN{}; + } + else if constexpr (Tile_N % 16 == 0) { + return SM90::GMMA::MMA_64x16x32_S32S8S8_RS_TN{}; + } + else if constexpr (Tile_N % 8 == 0) { + return SM90::GMMA::MMA_64x8x32_S32S8S8_RS_TN{}; + } + else { + static_assert(Tile_N % 8 == 0, "Tile_N must be a multiple of 8."); + } + } + + // Input A: int8_t ; Input B: uint8_t + else if constexpr (is_same_v && is_same_v) { + 
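      // As with the other 8-bit integer branches, both operands must be K-major and
      // Tile_K must be a multiple of 32 (the dense GMMA K extent for 8-bit inputs);
      // the static_asserts below enforce this before dispatching on Tile_N.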
static_assert(MajorA == GMMA::Major::K, "MajorA must be GMMA::Major::K for this config."); + static_assert(MajorB == GMMA::Major::K, "MajorB must be GMMA::Major::K for this config."); + static_assert(size<2>(TileShape_MNK{}) % 32 == 0, "Tile_K must be a multiple of 32."); + + if constexpr (Tile_N % 256 == 0) { + return SM90::GMMA::MMA_64x256x32_S32S8U8_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 240 == 0) { + return SM90::GMMA::MMA_64x240x32_S32S8U8_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 224 == 0) { + return SM90::GMMA::MMA_64x224x32_S32S8U8_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 208 == 0) { + return SM90::GMMA::MMA_64x208x32_S32S8U8_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 192 == 0) { + return SM90::GMMA::MMA_64x192x32_S32S8U8_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 176 == 0) { + return SM90::GMMA::MMA_64x176x32_S32S8U8_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 160 == 0) { + return SM90::GMMA::MMA_64x160x32_S32S8U8_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 144 == 0) { + return SM90::GMMA::MMA_64x144x32_S32S8U8_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 128 == 0) { + return SM90::GMMA::MMA_64x128x32_S32S8U8_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 112 == 0) { + return SM90::GMMA::MMA_64x112x32_S32S8U8_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 96 == 0) { + return SM90::GMMA::MMA_64x96x32_S32S8U8_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 80 == 0) { + return SM90::GMMA::MMA_64x80x32_S32S8U8_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 64 == 0) { + return SM90::GMMA::MMA_64x64x32_S32S8U8_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 48 == 0) { + return SM90::GMMA::MMA_64x48x32_S32S8U8_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 32 == 0) { + return SM90::GMMA::MMA_64x32x32_S32S8U8_RS_TN{}; + } + else if constexpr (Tile_N % 16 == 0) { + return SM90::GMMA::MMA_64x16x32_S32S8U8_RS_TN{}; + } + else if constexpr (Tile_N % 8 == 0) { + return SM90::GMMA::MMA_64x8x32_S32S8U8_RS_TN{}; + } + else { + static_assert(Tile_N % 8 == 0, "Tile_N must be a multiple of 8."); + } + } + + // Input A: uint8_t ; Input B: int8_t + else if constexpr (is_same_v && is_same_v) { + static_assert(MajorA == GMMA::Major::K, "MajorA must be GMMA::Major::K for this config."); + static_assert(MajorB == GMMA::Major::K, "MajorB must be GMMA::Major::K for this config."); + static_assert(size<2>(TileShape_MNK{}) % 32 == 0, "Tile_K must be a multiple of 32."); + + if constexpr (Tile_N % 256 == 0) { + return SM90::GMMA::MMA_64x256x32_S32U8S8_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 240 == 0) { + return SM90::GMMA::MMA_64x240x32_S32U8S8_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 224 == 0) { + return SM90::GMMA::MMA_64x224x32_S32U8S8_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 208 == 0) { + return SM90::GMMA::MMA_64x208x32_S32U8S8_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 192 == 0) { + return 
SM90::GMMA::MMA_64x192x32_S32U8S8_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 176 == 0) { + return SM90::GMMA::MMA_64x176x32_S32U8S8_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 160 == 0) { + return SM90::GMMA::MMA_64x160x32_S32U8S8_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 144 == 0) { + return SM90::GMMA::MMA_64x144x32_S32U8S8_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 128 == 0) { + return SM90::GMMA::MMA_64x128x32_S32U8S8_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 112 == 0) { + return SM90::GMMA::MMA_64x112x32_S32U8S8_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 96 == 0) { + return SM90::GMMA::MMA_64x96x32_S32U8S8_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 80 == 0) { + return SM90::GMMA::MMA_64x80x32_S32U8S8_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 64 == 0) { + return SM90::GMMA::MMA_64x64x32_S32U8S8_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 48 == 0) { + return SM90::GMMA::MMA_64x48x32_S32U8S8_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 32 == 0) { + return SM90::GMMA::MMA_64x32x32_S32U8S8_RS_TN{}; + } + else if constexpr (Tile_N % 16 == 0) { + return SM90::GMMA::MMA_64x16x32_S32U8S8_RS_TN{}; + } + else if constexpr (Tile_N % 8 == 0) { + return SM90::GMMA::MMA_64x8x32_S32U8S8_RS_TN{}; + } + else { + static_assert(Tile_N % 8 == 0, "Tile_N must be a multiple of 8."); + } + } + + // Input A: uint8_t ; Input B: uint8_t + else if constexpr (is_same_v && is_same_v) { + static_assert(MajorA == GMMA::Major::K, "MajorA must be GMMA::Major::K for this config."); + static_assert(MajorB == GMMA::Major::K, "MajorB must be GMMA::Major::K for this config."); + static_assert(size<2>(TileShape_MNK{}) % 32 == 0, "Tile_K must be a multiple of 32."); + + if constexpr (Tile_N % 256 == 0) { + return SM90::GMMA::MMA_64x256x32_S32U8U8_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 240 == 0) { + return SM90::GMMA::MMA_64x240x32_S32U8U8_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 224 == 0) { + return SM90::GMMA::MMA_64x224x32_S32U8U8_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 208 == 0) { + return SM90::GMMA::MMA_64x208x32_S32U8U8_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 192 == 0) { + return SM90::GMMA::MMA_64x192x32_S32U8U8_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 176 == 0) { + return SM90::GMMA::MMA_64x176x32_S32U8U8_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 160 == 0) { + return SM90::GMMA::MMA_64x160x32_S32U8U8_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 144 == 0) { + return SM90::GMMA::MMA_64x144x32_S32U8U8_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 128 == 0) { + return SM90::GMMA::MMA_64x128x32_S32U8U8_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 112 == 0) { + return SM90::GMMA::MMA_64x112x32_S32U8U8_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 96 == 0) { + return SM90::GMMA::MMA_64x96x32_S32U8U8_RS_TN{}; + } +#if 
defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 80 == 0) { + return SM90::GMMA::MMA_64x80x32_S32U8U8_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 64 == 0) { + return SM90::GMMA::MMA_64x64x32_S32U8U8_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 48 == 0) { + return SM90::GMMA::MMA_64x48x32_S32U8U8_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 32 == 0) { + return SM90::GMMA::MMA_64x32x32_S32U8U8_RS_TN{}; + } + else if constexpr (Tile_N % 16 == 0) { + return SM90::GMMA::MMA_64x16x32_S32U8U8_RS_TN{}; + } + else if constexpr (Tile_N % 8 == 0) { + return SM90::GMMA::MMA_64x8x32_S32U8U8_RS_TN{}; + } + else { + static_assert(Tile_N % 8 == 0, "Tile_N must be a multiple of 8."); + } + } + + else { + static_assert(sizeof(ElementA) == 0, "No eligible GMMA operator for request configuration."); + } + } + + // Unknown accumulator type + else { + static_assert(sizeof(ElementC) == 0, "Unknown ElementC accumulator type."); + } +} + +template < + class ElementA, + class ElementB, + class ElementC, + class TileShape_MNK, + GMMA::Major MajorA = GMMA::Major::K, + GMMA::Major MajorB = GMMA::Major::K, + auto... Args // e.g. GMMA::ScaleOut::One, [GMMA::ScaleIn::One, GMMA::ScaleIn::One] + // But most commonly leave empty for defaults +> +CUTE_HOST_DEVICE constexpr +auto +rs_op_selector_sparse() +{ + static_assert(is_static::value, "TileShape_MNK must be static."); + static_assert(rank(TileShape_MNK{}) == 3, "TileShape_MNK must be rank 3."); + static_assert(size<0>(TileShape_MNK{}) % 64 == 0, "Tile_M must be a multiple of 64."); + static_assert(MajorA == GMMA::Major::K, "Register source A operand GMMAs must have K-major A layout."); + auto Tile_N = size<1>(TileShape_MNK{}); + + // F16 accumulator + if constexpr (is_same_v) { + + // Input A: half_t ; Input B: half_t + if constexpr (is_same_v && is_same_v) { + static_assert(size<2>(TileShape_MNK{}) % 32 == 0, "Tile_K must be a multiple of 32."); + + if constexpr (Tile_N % 256 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x256x32_F16F16F16_RS{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 240 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x240x32_F16F16F16_RS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 224 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x224x32_F16F16F16_RS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 208 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x208x32_F16F16F16_RS{}; + } +#endif + else if constexpr (Tile_N % 192 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x192x32_F16F16F16_RS{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 176 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x176x32_F16F16F16_RS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 160 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x160x32_F16F16F16_RS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 144 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x144x32_F16F16F16_RS{}; + } +#endif + else if constexpr (Tile_N % 128 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x128x32_F16F16F16_RS{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 112 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x112x32_F16F16F16_RS{}; + } +#endif + else if constexpr (Tile_N % 96 == 0) { + return 
SM90::GMMA::SPARSE::GMMA_64x96x32_F16F16F16_RS{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 80 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x80x32_F16F16F16_RS{}; + } +#endif + else if constexpr (Tile_N % 64 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x64x32_F16F16F16_RS{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 48 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x48x32_F16F16F16_RS{}; + } +#endif + else if constexpr (Tile_N % 32 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x32x32_F16F16F16_RS{}; + } + else if constexpr (Tile_N % 16 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x16x32_F16F16F16_RS{}; + } + else if constexpr (Tile_N % 8 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x8x32_F16F16F16_RS{}; + } + else { + static_assert(Tile_N % 8 == 0, "Tile_N must be a multiple of 8."); + } + } + + // Input A: float_e4m3_t ; Input B: float_e4m3_t + else if constexpr (is_same_v && is_same_v) { + static_assert(MajorA == GMMA::Major::K, "MajorA must be GMMA::Major::K for this config."); + static_assert(MajorB == GMMA::Major::K, "MajorB must be GMMA::Major::K for this config."); + static_assert(size<2>(TileShape_MNK{}) % 64 == 0, "Tile_K must be a multiple of 64."); + + if constexpr (Tile_N % 256 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x256x64_F16E4M3E4M3_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 240 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x240x64_F16E4M3E4M3_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 224 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x224x64_F16E4M3E4M3_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 208 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x208x64_F16E4M3E4M3_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 192 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x192x64_F16E4M3E4M3_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 176 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x176x64_F16E4M3E4M3_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 160 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x160x64_F16E4M3E4M3_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 144 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x144x64_F16E4M3E4M3_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 128 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x128x64_F16E4M3E4M3_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 112 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x112x64_F16E4M3E4M3_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 96 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x96x64_F16E4M3E4M3_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 80 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x80x64_F16E4M3E4M3_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 64 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x64x64_F16E4M3E4M3_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 48 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x48x64_F16E4M3E4M3_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 32 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x32x64_F16E4M3E4M3_RS_TN{}; + } + else if constexpr (Tile_N % 16 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x16x64_F16E4M3E4M3_RS_TN{}; 
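      // The sparse FP8 branches advance K in steps of 64 (vs. 32 for the dense FP8
      // atoms selected earlier), which is why these branches assert Tile_K % 64 == 0.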
+ } + else if constexpr (Tile_N % 8 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x8x64_F16E4M3E4M3_RS_TN{}; + } + else { + static_assert(Tile_N % 8 == 0, "Tile_N must be a multiple of 8."); + } + } + + // Input A: float_e4m3_t ; Input B: float_e5m2_t + else if constexpr (is_same_v && is_same_v) { + static_assert(MajorA == GMMA::Major::K, "MajorA must be GMMA::Major::K for this config."); + static_assert(MajorB == GMMA::Major::K, "MajorB must be GMMA::Major::K for this config."); + static_assert(size<2>(TileShape_MNK{}) % 64 == 0, "Tile_K must be a multiple of 64."); + + if constexpr (Tile_N % 256 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x256x64_F16E4M3E5M2_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 240 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x240x64_F16E4M3E5M2_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 224 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x224x64_F16E4M3E5M2_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 208 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x208x64_F16E4M3E5M2_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 192 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x192x64_F16E4M3E5M2_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 176 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x176x64_F16E4M3E5M2_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 160 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x160x64_F16E4M3E5M2_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 144 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x144x64_F16E4M3E5M2_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 128 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x128x64_F16E4M3E5M2_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 112 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x112x64_F16E4M3E5M2_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 96 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x96x64_F16E4M3E5M2_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 80 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x80x64_F16E4M3E5M2_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 64 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x64x64_F16E4M3E5M2_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 48 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x48x64_F16E4M3E5M2_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 32 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x32x64_F16E4M3E5M2_RS_TN{}; + } + else if constexpr (Tile_N % 16 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x16x64_F16E4M3E5M2_RS_TN{}; + } + else if constexpr (Tile_N % 8 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x8x64_F16E4M3E5M2_RS_TN{}; + } + else { + static_assert(Tile_N % 8 == 0, "Tile_N must be a multiple of 8."); + } + } + + // Input A: float_e5m2_t ; Input B: float_e4m3_t + else if constexpr (is_same_v && is_same_v) { + static_assert(MajorA == GMMA::Major::K, "MajorA must be GMMA::Major::K for this config."); + static_assert(MajorB == GMMA::Major::K, "MajorB must be GMMA::Major::K for this config."); + static_assert(size<2>(TileShape_MNK{}) % 64 == 0, "Tile_K must be a multiple of 64."); + + if constexpr (Tile_N % 256 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x256x64_F16E5M2E4M3_RS_TN{}; + } +#if 
defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 240 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x240x64_F16E5M2E4M3_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 224 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x224x64_F16E5M2E4M3_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 208 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x208x64_F16E5M2E4M3_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 192 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x192x64_F16E5M2E4M3_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 176 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x176x64_F16E5M2E4M3_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 160 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x160x64_F16E5M2E4M3_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 144 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x144x64_F16E5M2E4M3_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 128 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x128x64_F16E5M2E4M3_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 112 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x112x64_F16E5M2E4M3_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 96 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x96x64_F16E5M2E4M3_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 80 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x80x64_F16E5M2E4M3_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 64 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x64x64_F16E5M2E4M3_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 48 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x48x64_F16E5M2E4M3_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 32 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x32x64_F16E5M2E4M3_RS_TN{}; + } + else if constexpr (Tile_N % 16 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x16x64_F16E5M2E4M3_RS_TN{}; + } + else if constexpr (Tile_N % 8 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x8x64_F16E5M2E4M3_RS_TN{}; + } + else { + static_assert(Tile_N % 8 == 0, "Tile_N must be a multiple of 8."); + } + } + + // Input A: float_e5m2_t ; Input B: float_e5m2_t + else if constexpr (is_same_v && is_same_v) { + static_assert(MajorA == GMMA::Major::K, "MajorA must be GMMA::Major::K for this config."); + static_assert(MajorB == GMMA::Major::K, "MajorB must be GMMA::Major::K for this config."); + static_assert(size<2>(TileShape_MNK{}) % 64 == 0, "Tile_K must be a multiple of 64."); + + if constexpr (Tile_N % 256 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x256x64_F16E5M2E5M2_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 240 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x240x64_F16E5M2E5M2_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 224 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x224x64_F16E5M2E5M2_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 208 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x208x64_F16E5M2E5M2_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 192 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x192x64_F16E5M2E5M2_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if 
constexpr (Tile_N % 176 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x176x64_F16E5M2E5M2_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 160 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x160x64_F16E5M2E5M2_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 144 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x144x64_F16E5M2E5M2_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 128 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x128x64_F16E5M2E5M2_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 112 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x112x64_F16E5M2E5M2_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 96 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x96x64_F16E5M2E5M2_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 80 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x80x64_F16E5M2E5M2_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 64 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x64x64_F16E5M2E5M2_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 48 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x48x64_F16E5M2E5M2_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 32 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x32x64_F16E5M2E5M2_RS_TN{}; + } + else if constexpr (Tile_N % 16 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x16x64_F16E5M2E5M2_RS_TN{}; + } + else if constexpr (Tile_N % 8 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x8x64_F16E5M2E5M2_RS_TN{}; + } + else { + static_assert(Tile_N % 8 == 0, "Tile_N must be a multiple of 8."); + } + } + + else { + static_assert(sizeof(ElementA) == 0, "No eligible GMMA operator for request configuration."); + } + } + + // F32 accumulator + else if constexpr (is_same_v) { + + // Input A: half_t ; Input B: half_t + if constexpr (is_same_v && is_same_v) { + static_assert(size<2>(TileShape_MNK{}) % 32 == 0, "Tile_K must be a multiple of 32."); + + if constexpr (Tile_N % 256 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x256x32_F32F16F16_RS{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 240 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x240x32_F32F16F16_RS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 224 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x224x32_F32F16F16_RS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 208 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x208x32_F32F16F16_RS{}; + } +#endif + else if constexpr (Tile_N % 192 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x192x32_F32F16F16_RS{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 176 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x176x32_F32F16F16_RS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 160 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x160x32_F32F16F16_RS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 144 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x144x32_F32F16F16_RS{}; + } +#endif + else if constexpr (Tile_N % 128 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x128x32_F32F16F16_RS{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 112 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x112x32_F32F16F16_RS{}; + } +#endif + else if constexpr (Tile_N % 96 
== 0) { + return SM90::GMMA::SPARSE::GMMA_64x96x32_F32F16F16_RS{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 80 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x80x32_F32F16F16_RS{}; + } +#endif + else if constexpr (Tile_N % 64 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x64x32_F32F16F16_RS{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 48 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x48x32_F32F16F16_RS{}; + } +#endif + else if constexpr (Tile_N % 32 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x32x32_F32F16F16_RS{}; + } + else if constexpr (Tile_N % 16 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x16x32_F32F16F16_RS{}; + } + else if constexpr (Tile_N % 8 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x8x32_F32F16F16_RS{}; + } + else { + static_assert(Tile_N % 8 == 0, "Tile_N must be a multiple of 8."); + } + } + + // Input A: bfloat16_t ; Input B: bfloat16_t + else if constexpr (is_same_v && is_same_v) { + static_assert(size<2>(TileShape_MNK{}) % 32 == 0, "Tile_K must be a multiple of 32."); + + if constexpr (Tile_N % 256 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x256x32_F32BF16BF16_RS{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 240 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x240x32_F32BF16BF16_RS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 224 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x224x32_F32BF16BF16_RS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 208 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x208x32_F32BF16BF16_RS{}; + } +#endif + else if constexpr (Tile_N % 192 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x192x32_F32BF16BF16_RS{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 176 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x176x32_F32BF16BF16_RS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 160 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x160x32_F32BF16BF16_RS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 144 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x144x32_F32BF16BF16_RS{}; + } +#endif + else if constexpr (Tile_N % 128 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x128x32_F32BF16BF16_RS{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 112 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x112x32_F32BF16BF16_RS{}; + } +#endif + else if constexpr (Tile_N % 96 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x96x32_F32BF16BF16_RS{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 80 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x80x32_F32BF16BF16_RS{}; + } +#endif + else if constexpr (Tile_N % 64 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x64x32_F32BF16BF16_RS{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 48 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x48x32_F32BF16BF16_RS{}; + } +#endif + else if constexpr (Tile_N % 32 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x32x32_F32BF16BF16_RS{}; + } + else if constexpr (Tile_N % 16 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x16x32_F32BF16BF16_RS{}; + } + else if constexpr (Tile_N % 8 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x8x32_F32BF16BF16_RS{}; + } + else { + static_assert(Tile_N % 8 == 0, "Tile_N must be a multiple of 8."); + } + } + + // Input A: tfloat32_t 
; Input B: tfloat32_t + else if constexpr (is_same_v && is_same_v) { + static_assert(MajorA == GMMA::Major::K, "MajorA must be GMMA::Major::K for this config."); + static_assert(MajorB == GMMA::Major::K, "MajorB must be GMMA::Major::K for this config."); + static_assert(size<2>(TileShape_MNK{}) % 16 == 0, "Tile_K must be a multiple of 16."); + + if constexpr (Tile_N % 256 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x256x16_F32TF32TF32_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 240 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x240x16_F32TF32TF32_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 224 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x224x16_F32TF32TF32_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 208 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x208x16_F32TF32TF32_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 192 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x192x16_F32TF32TF32_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 176 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x176x16_F32TF32TF32_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 160 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x160x16_F32TF32TF32_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 144 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x144x16_F32TF32TF32_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 128 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x128x16_F32TF32TF32_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 112 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x112x16_F32TF32TF32_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 96 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x96x16_F32TF32TF32_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 80 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x80x16_F32TF32TF32_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 64 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x64x16_F32TF32TF32_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 48 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x48x16_F32TF32TF32_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 32 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x32x16_F32TF32TF32_RS_TN{}; + } + else if constexpr (Tile_N % 16 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x16x16_F32TF32TF32_RS_TN{}; + } + else if constexpr (Tile_N % 8 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x8x16_F32TF32TF32_RS_TN{}; + } + else { + static_assert(Tile_N % 8 == 0, "Tile_N must be a multiple of 8."); + } + } + + // Input A: float_e4m3_t ; Input B: float_e4m3_t + else if constexpr (is_same_v && is_same_v) { + static_assert(MajorA == GMMA::Major::K, "MajorA must be GMMA::Major::K for this config."); + static_assert(MajorB == GMMA::Major::K, "MajorB must be GMMA::Major::K for this config."); + static_assert(size<2>(TileShape_MNK{}) % 64 == 0, "Tile_K must be a multiple of 64."); + + if constexpr (Tile_N % 256 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x256x64_F32E4M3E4M3_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 240 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x240x64_F32E4M3E4M3_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if 
constexpr (Tile_N % 224 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x224x64_F32E4M3E4M3_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 208 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x208x64_F32E4M3E4M3_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 192 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x192x64_F32E4M3E4M3_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 176 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x176x64_F32E4M3E4M3_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 160 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x160x64_F32E4M3E4M3_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 144 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x144x64_F32E4M3E4M3_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 128 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x128x64_F32E4M3E4M3_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 112 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x112x64_F32E4M3E4M3_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 96 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x96x64_F32E4M3E4M3_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 80 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x80x64_F32E4M3E4M3_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 64 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x64x64_F32E4M3E4M3_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 48 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x48x64_F32E4M3E4M3_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 32 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x32x64_F32E4M3E4M3_RS_TN{}; + } + else if constexpr (Tile_N % 16 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x16x64_F32E4M3E4M3_RS_TN{}; + } + else if constexpr (Tile_N % 8 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x8x64_F32E4M3E4M3_RS_TN{}; + } + else { + static_assert(Tile_N % 8 == 0, "Tile_N must be a multiple of 8."); + } + } + + // Input A: float_e4m3_t ; Input B: float_e5m2_t + else if constexpr (is_same_v && is_same_v) { + static_assert(MajorA == GMMA::Major::K, "MajorA must be GMMA::Major::K for this config."); + static_assert(MajorB == GMMA::Major::K, "MajorB must be GMMA::Major::K for this config."); + static_assert(size<2>(TileShape_MNK{}) % 64 == 0, "Tile_K must be a multiple of 64."); + + if constexpr (Tile_N % 256 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x256x64_F32E4M3E5M2_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 240 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x240x64_F32E4M3E5M2_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 224 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x224x64_F32E4M3E5M2_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 208 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x208x64_F32E4M3E5M2_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 192 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x192x64_F32E4M3E5M2_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 176 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x176x64_F32E4M3E5M2_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 160 == 0) { + return 
SM90::GMMA::SPARSE::GMMA_64x160x64_F32E4M3E5M2_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 144 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x144x64_F32E4M3E5M2_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 128 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x128x64_F32E4M3E5M2_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 112 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x112x64_F32E4M3E5M2_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 96 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x96x64_F32E4M3E5M2_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 80 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x80x64_F32E4M3E5M2_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 64 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x64x64_F32E4M3E5M2_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 48 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x48x64_F32E4M3E5M2_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 32 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x32x64_F32E4M3E5M2_RS_TN{}; + } + else if constexpr (Tile_N % 16 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x16x64_F32E4M3E5M2_RS_TN{}; + } + else if constexpr (Tile_N % 8 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x8x64_F32E4M3E5M2_RS_TN{}; + } + else { + static_assert(Tile_N % 8 == 0, "Tile_N must be a multiple of 8."); + } + } + + // Input A: float_e5m2_t ; Input B: float_e4m3_t + else if constexpr (is_same_v && is_same_v) { + static_assert(MajorA == GMMA::Major::K, "MajorA must be GMMA::Major::K for this config."); + static_assert(MajorB == GMMA::Major::K, "MajorB must be GMMA::Major::K for this config."); + static_assert(size<2>(TileShape_MNK{}) % 64 == 0, "Tile_K must be a multiple of 64."); + + if constexpr (Tile_N % 256 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x256x64_F32E5M2E4M3_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 240 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x240x64_F32E5M2E4M3_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 224 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x224x64_F32E5M2E4M3_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 208 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x208x64_F32E5M2E4M3_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 192 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x192x64_F32E5M2E4M3_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 176 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x176x64_F32E5M2E4M3_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 160 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x160x64_F32E5M2E4M3_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 144 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x144x64_F32E5M2E4M3_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 128 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x128x64_F32E5M2E4M3_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 112 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x112x64_F32E5M2E4M3_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 96 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x96x64_F32E5M2E4M3_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + 
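      // Shapes with N in {240, 224, 208, 176, 160, 144, 112, 80, 48} are only
      // candidates when CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED is defined; without it,
      // the dispatch falls through to the next unguarded shape in the cascade.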
else if constexpr (Tile_N % 80 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x80x64_F32E5M2E4M3_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 64 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x64x64_F32E5M2E4M3_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 48 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x48x64_F32E5M2E4M3_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 32 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x32x64_F32E5M2E4M3_RS_TN{}; + } + else if constexpr (Tile_N % 16 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x16x64_F32E5M2E4M3_RS_TN{}; + } + else if constexpr (Tile_N % 8 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x8x64_F32E5M2E4M3_RS_TN{}; + } + else { + static_assert(Tile_N % 8 == 0, "Tile_N must be a multiple of 8."); + } + } + + // Input A: float_e5m2_t ; Input B: float_e5m2_t + else if constexpr (is_same_v && is_same_v) { + static_assert(MajorA == GMMA::Major::K, "MajorA must be GMMA::Major::K for this config."); + static_assert(MajorB == GMMA::Major::K, "MajorB must be GMMA::Major::K for this config."); + static_assert(size<2>(TileShape_MNK{}) % 64 == 0, "Tile_K must be a multiple of 64."); + + if constexpr (Tile_N % 256 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x256x64_F32E5M2E5M2_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 240 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x240x64_F32E5M2E5M2_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 224 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x224x64_F32E5M2E5M2_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 208 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x208x64_F32E5M2E5M2_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 192 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x192x64_F32E5M2E5M2_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 176 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x176x64_F32E5M2E5M2_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 160 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x160x64_F32E5M2E5M2_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 144 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x144x64_F32E5M2E5M2_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 128 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x128x64_F32E5M2E5M2_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 112 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x112x64_F32E5M2E5M2_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 96 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x96x64_F32E5M2E5M2_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 80 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x80x64_F32E5M2E5M2_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 64 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x64x64_F32E5M2E5M2_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 48 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x48x64_F32E5M2E5M2_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 32 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x32x64_F32E5M2E5M2_RS_TN{}; + } + else if constexpr (Tile_N % 16 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x16x64_F32E5M2E5M2_RS_TN{}; + } + else if constexpr (Tile_N % 8 == 0) { + return 
SM90::GMMA::SPARSE::GMMA_64x8x64_F32E5M2E5M2_RS_TN{}; + } + else { + static_assert(Tile_N % 8 == 0, "Tile_N must be a multiple of 8."); + } + } + + else { + static_assert(sizeof(ElementA) == 0, "No eligible GMMA operator for request configuration."); + } + } + + // S32 accumulator + else if constexpr (is_same_v) { + + // Input A: int8_t ; Input B: int8_t + if constexpr (is_same_v && is_same_v) { + static_assert(MajorA == GMMA::Major::K, "MajorA must be GMMA::Major::K for this config."); + static_assert(MajorB == GMMA::Major::K, "MajorB must be GMMA::Major::K for this config."); + static_assert(size<2>(TileShape_MNK{}) % 64 == 0, "Tile_K must be a multiple of 64."); + + if constexpr (Tile_N % 256 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x256x64_S32S8S8_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 240 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x240x64_S32S8S8_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 224 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x224x64_S32S8S8_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 208 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x208x64_S32S8S8_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 192 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x192x64_S32S8S8_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 176 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x176x64_S32S8S8_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 160 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x160x64_S32S8S8_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 144 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x144x64_S32S8S8_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 128 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x128x64_S32S8S8_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 112 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x112x64_S32S8S8_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 96 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x96x64_S32S8S8_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 80 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x80x64_S32S8S8_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 64 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x64x64_S32S8S8_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 48 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x48x64_S32S8S8_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 32 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x32x64_S32S8S8_RS_TN{}; + } + else if constexpr (Tile_N % 16 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x16x64_S32S8S8_RS_TN{}; + } + else if constexpr (Tile_N % 8 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x8x64_S32S8S8_RS_TN{}; + } + else { + static_assert(Tile_N % 8 == 0, "Tile_N must be a multiple of 8."); + } + } + + // Input A: int8_t ; Input B: uint8_t + else if constexpr (is_same_v && is_same_v) { + static_assert(MajorA == GMMA::Major::K, "MajorA must be GMMA::Major::K for this config."); + static_assert(MajorB == GMMA::Major::K, "MajorB must be GMMA::Major::K for this config."); + static_assert(size<2>(TileShape_MNK{}) % 64 == 0, "Tile_K must be a multiple of 64."); + + if constexpr (Tile_N % 256 == 0) { + return 
SM90::GMMA::SPARSE::GMMA_64x256x64_S32S8U8_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 240 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x240x64_S32S8U8_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 224 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x224x64_S32S8U8_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 208 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x208x64_S32S8U8_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 192 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x192x64_S32S8U8_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 176 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x176x64_S32S8U8_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 160 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x160x64_S32S8U8_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 144 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x144x64_S32S8U8_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 128 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x128x64_S32S8U8_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 112 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x112x64_S32S8U8_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 96 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x96x64_S32S8U8_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 80 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x80x64_S32S8U8_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 64 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x64x64_S32S8U8_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 48 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x48x64_S32S8U8_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 32 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x32x64_S32S8U8_RS_TN{}; + } + else if constexpr (Tile_N % 16 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x16x64_S32S8U8_RS_TN{}; + } + else if constexpr (Tile_N % 8 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x8x64_S32S8U8_RS_TN{}; + } + else { + static_assert(Tile_N % 8 == 0, "Tile_N must be a multiple of 8."); + } + } + + // Input A: uint8_t ; Input B: int8_t + else if constexpr (is_same_v && is_same_v) { + static_assert(MajorA == GMMA::Major::K, "MajorA must be GMMA::Major::K for this config."); + static_assert(MajorB == GMMA::Major::K, "MajorB must be GMMA::Major::K for this config."); + static_assert(size<2>(TileShape_MNK{}) % 64 == 0, "Tile_K must be a multiple of 64."); + + if constexpr (Tile_N % 256 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x256x64_S32U8S8_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 240 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x240x64_S32U8S8_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 224 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x224x64_S32U8S8_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 208 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x208x64_S32U8S8_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 192 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x192x64_S32U8S8_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 176 == 0) { + 
return SM90::GMMA::SPARSE::GMMA_64x176x64_S32U8S8_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 160 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x160x64_S32U8S8_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 144 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x144x64_S32U8S8_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 128 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x128x64_S32U8S8_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 112 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x112x64_S32U8S8_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 96 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x96x64_S32U8S8_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 80 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x80x64_S32U8S8_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 64 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x64x64_S32U8S8_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 48 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x48x64_S32U8S8_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 32 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x32x64_S32U8S8_RS_TN{}; + } + else if constexpr (Tile_N % 16 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x16x64_S32U8S8_RS_TN{}; + } + else if constexpr (Tile_N % 8 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x8x64_S32U8S8_RS_TN{}; + } + else { + static_assert(Tile_N % 8 == 0, "Tile_N must be a multiple of 8."); + } + } + + // Input A: uint8_t ; Input B: uint8_t + else if constexpr (is_same_v && is_same_v) { + static_assert(MajorA == GMMA::Major::K, "MajorA must be GMMA::Major::K for this config."); + static_assert(MajorB == GMMA::Major::K, "MajorB must be GMMA::Major::K for this config."); + static_assert(size<2>(TileShape_MNK{}) % 64 == 0, "Tile_K must be a multiple of 64."); + + if constexpr (Tile_N % 256 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x256x64_S32U8U8_RS_TN{}; } #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 240 == 0) { - return SM90_64x240x32_S32U8U8_RS_TN{}; + return SM90::GMMA::SPARSE::GMMA_64x240x64_S32U8U8_RS_TN{}; } #endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 224 == 0) { - return SM90_64x224x32_S32U8U8_RS_TN{}; + return SM90::GMMA::SPARSE::GMMA_64x224x64_S32U8U8_RS_TN{}; } #endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 208 == 0) { - return SM90_64x208x32_S32U8U8_RS_TN{}; + return SM90::GMMA::SPARSE::GMMA_64x208x64_S32U8U8_RS_TN{}; } #endif else if constexpr (Tile_N % 192 == 0) { - return SM90_64x192x32_S32U8U8_RS_TN{}; + return SM90::GMMA::SPARSE::GMMA_64x192x64_S32U8U8_RS_TN{}; } #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 176 == 0) { - return SM90_64x176x32_S32U8U8_RS_TN{}; + return SM90::GMMA::SPARSE::GMMA_64x176x64_S32U8U8_RS_TN{}; } #endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 160 == 0) { - return SM90_64x160x32_S32U8U8_RS_TN{}; + return SM90::GMMA::SPARSE::GMMA_64x160x64_S32U8U8_RS_TN{}; } #endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 144 == 0) { - return SM90_64x144x32_S32U8U8_RS_TN{}; + return SM90::GMMA::SPARSE::GMMA_64x144x64_S32U8U8_RS_TN{}; } #endif else if constexpr (Tile_N % 128 == 0) { - return SM90_64x128x32_S32U8U8_RS_TN{}; + return 
SM90::GMMA::SPARSE::GMMA_64x128x64_S32U8U8_RS_TN{}; } #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 112 == 0) { - return SM90_64x112x32_S32U8U8_RS_TN{}; + return SM90::GMMA::SPARSE::GMMA_64x112x64_S32U8U8_RS_TN{}; } #endif else if constexpr (Tile_N % 96 == 0) { - return SM90_64x96x32_S32U8U8_RS_TN{}; + return SM90::GMMA::SPARSE::GMMA_64x96x64_S32U8U8_RS_TN{}; } #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 80 == 0) { - return SM90_64x80x32_S32U8U8_RS_TN{}; + return SM90::GMMA::SPARSE::GMMA_64x80x64_S32U8U8_RS_TN{}; } #endif else if constexpr (Tile_N % 64 == 0) { - return SM90_64x64x32_S32U8U8_RS_TN{}; + return SM90::GMMA::SPARSE::GMMA_64x64x64_S32U8U8_RS_TN{}; } #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 48 == 0) { - return SM90_64x48x32_S32U8U8_RS_TN{}; + return SM90::GMMA::SPARSE::GMMA_64x48x64_S32U8U8_RS_TN{}; } #endif else if constexpr (Tile_N % 32 == 0) { - return SM90_64x32x32_S32U8U8_RS_TN{}; + return SM90::GMMA::SPARSE::GMMA_64x32x64_S32U8U8_RS_TN{}; } else if constexpr (Tile_N % 16 == 0) { - return SM90_64x16x32_S32U8U8_RS_TN{}; + return SM90::GMMA::SPARSE::GMMA_64x16x64_S32U8U8_RS_TN{}; } else if constexpr (Tile_N % 8 == 0) { - return SM90_64x8x32_S32U8U8_RS_TN{}; + return SM90::GMMA::SPARSE::GMMA_64x8x64_S32U8U8_RS_TN{}; } else { static_assert(Tile_N % 8 == 0, "Tile_N must be a multiple of 8."); @@ -2990,7 +5645,7 @@ rs_op_selector() } } -} // end namespace GMMA +} // end namespace SM90::GMMA } // end namespace cute //////////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/include/cute/arch/mma_sm90_desc.hpp b/include/cute/arch/mma_sm90_desc.hpp index 1d6caba89d..a53a9748b4 100644 --- a/include/cute/arch/mma_sm90_desc.hpp +++ b/include/cute/arch/mma_sm90_desc.hpp @@ -48,8 +48,7 @@ namespace cute { // GMMA Descriptor and utilities // GMMA enums and utilities -namespace GMMA -{ +namespace SM90::GMMA { enum class LayoutType : uint8_t { INTERLEAVE = 0, @@ -81,7 +80,7 @@ CUTE_HOST std::ostream& operator<<(std::ostream& os, LayoutType const& t) { } #endif // !defined(__CUDACC_RTC__) -} // end namespace GMMA +} // end namespace SM90::GMMA union GmmaDescriptor { @@ -146,7 +145,7 @@ print(GmmaDescriptor const& t) printf(" leading_off: 0x%04x (%d)\n", t.bitfield.leading_byte_offset_, t.bitfield.leading_byte_offset_); printf(" stride_off : 0x%04x (%d)\n", t.bitfield.stride_byte_offset_, t.bitfield.stride_byte_offset_); printf(" base_offset: 0x%01x\n", t.bitfield.base_offset_); - printf(" layout_type: 0x%01x (%s)\n", t.bitfield.layout_type_, to_string(static_cast(t.bitfield.layout_type_))); + printf(" layout_type: 0x%01x (%s)\n", t.bitfield.layout_type_, to_string(static_cast(t.bitfield.layout_type_))); #endif // !defined(__CUDACC_RTC__) } diff --git a/include/cute/arch/mma_sm90_gmma.hpp b/include/cute/arch/mma_sm90_gmma.hpp index aebb8fab5a..4dc01463b7 100644 --- a/include/cute/arch/mma_sm90_gmma.hpp +++ b/include/cute/arch/mma_sm90_gmma.hpp @@ -30,8 +30,10 @@ **************************************************************************************************/ #pragma once -#include -#include +#include // CUTE_HOST_DEVICE + +#include "cutlass/arch/synclog.hpp" + // Config #if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900) && defined(__CUDA_ARCH_FEAT_SM90_ALL)) # define CUTE_ARCH_MMA_SM90A_ENABLED @@ -47,6 +49,7 @@ void warpgroup_arrive() { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + 
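  // Record the event for the synchronization log before issuing the wgmma fence.
  // The synclog_emit_* hooks come from cutlass/arch/synclog.hpp and are assumed to
  // compile away to no-ops unless synclog support is enabled at build time.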
cutlass::arch::synclog_emit_warpgroup_arrive(__LINE__); asm volatile ("wgmma.fence.sync.aligned;\n" ::: "memory"); #else CUTE_INVALID_CONTROL_PATH("Attempting to use wgmma.fence without CUTE_ARCH_MMA_SM90A_ENABLED"); @@ -60,6 +63,7 @@ warpgroup_wait() { static_assert(N >= 0 && N <= 7, "WGMMA wait: N must be in range [0, 7]"); #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_warpgroup_wait(__LINE__, N); asm volatile("wgmma.wait_group.sync.aligned %0;\n" :: "n"(N) : "memory"); #else CUTE_INVALID_CONTROL_PATH("Attempting to use wgmma.wait_group without CUTE_ARCH_MMA_SM90A_ENABLED"); @@ -72,6 +76,7 @@ void warpgroup_commit_batch() { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_warpgroup_commit_batch(__LINE__); asm volatile("wgmma.commit_group.sync.aligned;\n" ::: "memory"); #else CUTE_INVALID_CONTROL_PATH("Attempting to use wgmma.commit_group without CUTE_ARCH_MMA_SM90A_ENABLED"); @@ -97,7 +102,7 @@ warpgroup_fence_operand(float& reg) { #endif } -namespace GMMA { +namespace SM90::GMMA { enum class Major { K = 0, @@ -114,7 +119,11 @@ enum class ScaleIn { One = 1 }; -} // namespace GMMA +enum class SparseSel { + Zero = 0, + One = 1 +}; + //////////////////////////////////////////////////////////////////////////////////////////////////// // GMMA PTX definitions: C = (scaleA * A) * (scaleB * B) + (scaleD * C) @@ -127,7 +136,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x8x16_F16F16F16_SS +struct MMA_64x8x16_F16F16F16_SS { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -141,6 +150,7 @@ struct SM90_64x8x16_F16F16F16_SS GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -156,7 +166,7 @@ struct SM90_64x8x16_F16F16F16_SS "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x8x16_F16F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x8x16_F16F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -170,7 +180,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x8x16_F16F16F16_RS +struct MMA_64x8x16_F16F16F16_RS { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -187,6 +197,7 @@ struct SM90_64x8x16_F16F16F16_RS GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -202,7 +213,7 @@ struct SM90_64x8x16_F16F16F16_RS "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x8x16_F16F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x8x16_F16F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -216,7 +227,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x16x16_F16F16F16_SS +struct MMA_64x16x16_F16F16F16_SS { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -230,6 +241,7 @@ struct SM90_64x16x16_F16F16F16_SS GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if 
defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -245,7 +257,7 @@ struct SM90_64x16x16_F16F16F16_SS "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x16x16_F16F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x16x16_F16F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -259,7 +271,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x16x16_F16F16F16_RS +struct MMA_64x16x16_F16F16F16_RS { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -276,6 +288,7 @@ struct SM90_64x16x16_F16F16F16_RS GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -291,7 +304,7 @@ struct SM90_64x16x16_F16F16F16_RS "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x16x16_F16F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x16x16_F16F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -305,7 +318,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x32x16_F16F16F16_SS +struct MMA_64x32x16_F16F16F16_SS { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -320,6 +333,7 @@ struct SM90_64x32x16_F16F16F16_SS GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -336,7 +350,7 @@ struct SM90_64x32x16_F16F16F16_SS "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x32x16_F16F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x32x16_F16F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -350,7 +364,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x32x16_F16F16F16_RS +struct MMA_64x32x16_F16F16F16_RS { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -368,6 +382,7 @@ struct SM90_64x32x16_F16F16F16_RS GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -384,7 +399,7 @@ struct SM90_64x32x16_F16F16F16_RS "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x32x16_F16F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x32x16_F16F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -399,7 +414,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x48x16_F16F16F16_SS +struct MMA_64x48x16_F16F16F16_SS { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -415,6 +430,7 @@ 
struct SM90_64x48x16_F16F16F16_SS GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -433,7 +449,7 @@ struct SM90_64x48x16_F16F16F16_SS "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x48x16_F16F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x48x16_F16F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -449,7 +465,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x48x16_F16F16F16_RS +struct MMA_64x48x16_F16F16F16_RS { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -468,6 +484,7 @@ struct SM90_64x48x16_F16F16F16_RS GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -486,7 +503,7 @@ struct SM90_64x48x16_F16F16F16_RS "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x48x16_F16F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x48x16_F16F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -501,7 +518,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x64x16_F16F16F16_SS +struct MMA_64x64x16_F16F16F16_SS { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -518,6 +535,7 @@ struct SM90_64x64x16_F16F16F16_SS GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -537,7 +555,7 @@ struct SM90_64x64x16_F16F16F16_SS "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x64x16_F16F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x64x16_F16F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -551,7 +569,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x64x16_F16F16F16_RS +struct MMA_64x64x16_F16F16F16_RS { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -571,6 +589,7 @@ struct SM90_64x64x16_F16F16F16_RS GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -590,7 +609,7 @@ struct SM90_64x64x16_F16F16F16_RS "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x64x16_F16F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x64x16_F16F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -605,7 +624,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x80x16_F16F16F16_SS +struct 
MMA_64x80x16_F16F16F16_SS { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -623,6 +642,7 @@ struct SM90_64x80x16_F16F16F16_SS GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -644,7 +664,7 @@ struct SM90_64x80x16_F16F16F16_SS "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x80x16_F16F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x80x16_F16F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -660,7 +680,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x80x16_F16F16F16_RS +struct MMA_64x80x16_F16F16F16_RS { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -681,6 +701,7 @@ struct SM90_64x80x16_F16F16F16_RS GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -702,7 +723,7 @@ struct SM90_64x80x16_F16F16F16_RS "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x80x16_F16F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x80x16_F16F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -717,7 +738,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x96x16_F16F16F16_SS +struct MMA_64x96x16_F16F16F16_SS { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -736,6 +757,7 @@ struct SM90_64x96x16_F16F16F16_SS GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -758,7 +780,7 @@ struct SM90_64x96x16_F16F16F16_SS "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x96x16_F16F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x96x16_F16F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -772,7 +794,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x96x16_F16F16F16_RS +struct MMA_64x96x16_F16F16F16_RS { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -794,6 +816,7 @@ struct SM90_64x96x16_F16F16F16_RS GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -816,7 +839,7 @@ struct SM90_64x96x16_F16F16F16_RS "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x96x16_F16F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x96x16_F16F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -831,7 +854,7 @@ template < GMMA::ScaleIn scaleA = 
GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x112x16_F16F16F16_SS +struct MMA_64x112x16_F16F16F16_SS { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -851,6 +874,7 @@ struct SM90_64x112x16_F16F16F16_SS GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -875,7 +899,7 @@ struct SM90_64x112x16_F16F16F16_SS "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x112x16_F16F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x112x16_F16F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -891,7 +915,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x112x16_F16F16F16_RS +struct MMA_64x112x16_F16F16F16_RS { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -914,6 +938,7 @@ struct SM90_64x112x16_F16F16F16_RS GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -938,7 +963,7 @@ struct SM90_64x112x16_F16F16F16_RS "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x112x16_F16F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x112x16_F16F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -953,7 +978,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x128x16_F16F16F16_SS +struct MMA_64x128x16_F16F16F16_SS { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -974,6 +999,7 @@ struct SM90_64x128x16_F16F16F16_SS GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -999,7 +1025,7 @@ struct SM90_64x128x16_F16F16F16_SS "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x128x16_F16F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x128x16_F16F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -1013,7 +1039,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x128x16_F16F16F16_RS +struct MMA_64x128x16_F16F16F16_RS { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -1037,6 +1063,7 @@ struct SM90_64x128x16_F16F16F16_RS GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -1062,7 +1089,7 @@ struct SM90_64x128x16_F16F16F16_RS "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x128x16_F16F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use 
MMA_64x128x16_F16F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -1077,7 +1104,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x144x16_F16F16F16_SS +struct MMA_64x144x16_F16F16F16_SS { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -1099,6 +1126,7 @@ struct SM90_64x144x16_F16F16F16_SS GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -1126,7 +1154,7 @@ struct SM90_64x144x16_F16F16F16_SS "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x144x16_F16F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x144x16_F16F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -1142,7 +1170,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x144x16_F16F16F16_RS +struct MMA_64x144x16_F16F16F16_RS { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -1167,6 +1195,7 @@ struct SM90_64x144x16_F16F16F16_RS GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -1194,7 +1223,7 @@ struct SM90_64x144x16_F16F16F16_RS "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x144x16_F16F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x144x16_F16F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -1210,7 +1239,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x160x16_F16F16F16_SS +struct MMA_64x160x16_F16F16F16_SS { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -1233,6 +1262,7 @@ struct SM90_64x160x16_F16F16F16_SS GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -1261,7 +1291,7 @@ struct SM90_64x160x16_F16F16F16_SS "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x160x16_F16F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x160x16_F16F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -1277,7 +1307,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x160x16_F16F16F16_RS +struct MMA_64x160x16_F16F16F16_RS { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -1303,6 +1333,7 @@ struct SM90_64x160x16_F16F16F16_RS GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -1331,7 +1362,7 @@ struct SM90_64x160x16_F16F16F16_RS "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); #else - 
CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x160x16_F16F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x160x16_F16F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -1347,7 +1378,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x176x16_F16F16F16_SS +struct MMA_64x176x16_F16F16F16_SS { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -1371,6 +1402,7 @@ struct SM90_64x176x16_F16F16F16_SS GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -1401,7 +1433,7 @@ struct SM90_64x176x16_F16F16F16_SS "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x176x16_F16F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x176x16_F16F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -1417,7 +1449,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x176x16_F16F16F16_RS +struct MMA_64x176x16_F16F16F16_RS { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -1444,6 +1476,7 @@ struct SM90_64x176x16_F16F16F16_RS GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -1474,7 +1507,7 @@ struct SM90_64x176x16_F16F16F16_RS "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x176x16_F16F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x176x16_F16F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -1489,7 +1522,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x192x16_F16F16F16_SS +struct MMA_64x192x16_F16F16F16_SS { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -1514,6 +1547,7 @@ struct SM90_64x192x16_F16F16F16_SS GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -1545,7 +1579,7 @@ struct SM90_64x192x16_F16F16F16_SS "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x192x16_F16F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x192x16_F16F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -1559,7 +1593,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x192x16_F16F16F16_RS +struct MMA_64x192x16_F16F16F16_RS { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -1587,6 +1621,7 @@ struct SM90_64x192x16_F16F16F16_RS GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -1618,7 +1653,7 @@ struct 
SM90_64x192x16_F16F16F16_RS "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x192x16_F16F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x192x16_F16F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -1633,7 +1668,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x208x16_F16F16F16_SS +struct MMA_64x208x16_F16F16F16_SS { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -1659,6 +1694,7 @@ struct SM90_64x208x16_F16F16F16_SS GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -1692,7 +1728,7 @@ struct SM90_64x208x16_F16F16F16_SS "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x208x16_F16F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x208x16_F16F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -1708,7 +1744,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x208x16_F16F16F16_RS +struct MMA_64x208x16_F16F16F16_RS { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -1737,6 +1773,7 @@ struct SM90_64x208x16_F16F16F16_RS GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -1770,7 +1807,7 @@ struct SM90_64x208x16_F16F16F16_RS "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x208x16_F16F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x208x16_F16F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -1786,7 +1823,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x224x16_F16F16F16_SS +struct MMA_64x224x16_F16F16F16_SS { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -1813,6 +1850,7 @@ struct SM90_64x224x16_F16F16F16_SS GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -1847,7 +1885,7 @@ struct SM90_64x224x16_F16F16F16_SS "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x224x16_F16F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x224x16_F16F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -1863,7 +1901,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x224x16_F16F16F16_RS +struct MMA_64x224x16_F16F16F16_RS { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -1893,6 +1931,7 @@ struct SM90_64x224x16_F16F16F16_RS GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if 
defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -1927,7 +1966,7 @@ struct SM90_64x224x16_F16F16F16_RS "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x224x16_F16F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x224x16_F16F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -1943,7 +1982,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x240x16_F16F16F16_SS +struct MMA_64x240x16_F16F16F16_SS { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -1971,6 +2010,7 @@ struct SM90_64x240x16_F16F16F16_SS GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -2007,7 +2047,7 @@ struct SM90_64x240x16_F16F16F16_SS "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x240x16_F16F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x240x16_F16F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -2023,7 +2063,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x240x16_F16F16F16_RS +struct MMA_64x240x16_F16F16F16_RS { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -2054,6 +2094,7 @@ struct SM90_64x240x16_F16F16F16_RS GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -2090,7 +2131,7 @@ struct SM90_64x240x16_F16F16F16_RS "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x240x16_F16F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x240x16_F16F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -2105,7 +2146,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x256x16_F16F16F16_SS +struct MMA_64x256x16_F16F16F16_SS { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -2134,6 +2175,7 @@ struct SM90_64x256x16_F16F16F16_SS GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -2171,7 +2213,7 @@ struct SM90_64x256x16_F16F16F16_SS "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x256x16_F16F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x256x16_F16F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -2185,7 +2227,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x256x16_F16F16F16_RS +struct MMA_64x256x16_F16F16F16_RS { using DRegisters = void; using 
ARegisters = uint32_t[4]; @@ -2217,6 +2259,7 @@ struct SM90_64x256x16_F16F16F16_RS GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -2254,7 +2297,7 @@ struct SM90_64x256x16_F16F16F16_RS "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x256x16_F16F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x256x16_F16F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -2268,7 +2311,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x8x16_F32F16F16_SS +struct MMA_64x8x16_F32F16F16_SS { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -2282,6 +2325,7 @@ struct SM90_64x8x16_F32F16F16_SS GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -2297,7 +2341,7 @@ struct SM90_64x8x16_F32F16F16_SS "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x8x16_F32F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x8x16_F32F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -2311,7 +2355,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x8x16_F32F16F16_RS +struct MMA_64x8x16_F32F16F16_RS { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -2328,6 +2372,7 @@ struct SM90_64x8x16_F32F16F16_RS GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -2343,7 +2388,7 @@ struct SM90_64x8x16_F32F16F16_RS "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x8x16_F32F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x8x16_F32F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -2357,7 +2402,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x16x16_F32F16F16_SS +struct MMA_64x16x16_F32F16F16_SS { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -2372,6 +2417,7 @@ struct SM90_64x16x16_F32F16F16_SS GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -2388,7 +2434,7 @@ struct SM90_64x16x16_F32F16F16_SS "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x16x16_F32F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x16x16_F32F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -2402,7 +2448,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = 
GMMA::ScaleIn::One > -struct SM90_64x16x16_F32F16F16_RS +struct MMA_64x16x16_F32F16F16_RS { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -2420,6 +2466,7 @@ struct SM90_64x16x16_F32F16F16_RS GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -2436,7 +2483,7 @@ struct SM90_64x16x16_F32F16F16_RS "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x16x16_F32F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x16x16_F32F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -2450,7 +2497,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x32x16_F32F16F16_SS +struct MMA_64x32x16_F32F16F16_SS { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -2467,6 +2514,7 @@ struct SM90_64x32x16_F32F16F16_SS GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -2486,7 +2534,7 @@ struct SM90_64x32x16_F32F16F16_SS "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x32x16_F32F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x32x16_F32F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -2500,7 +2548,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x32x16_F32F16F16_RS +struct MMA_64x32x16_F32F16F16_RS { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -2520,6 +2568,7 @@ struct SM90_64x32x16_F32F16F16_RS GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -2539,7 +2588,7 @@ struct SM90_64x32x16_F32F16F16_RS "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x32x16_F32F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x32x16_F32F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -2554,7 +2603,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x48x16_F32F16F16_SS +struct MMA_64x48x16_F32F16F16_SS { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -2573,6 +2622,7 @@ struct SM90_64x48x16_F32F16F16_SS GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -2595,7 +2645,7 @@ struct SM90_64x48x16_F32F16F16_SS "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x48x16_F32F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x48x16_F32F16F16_SS without 
CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -2611,7 +2661,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x48x16_F32F16F16_RS +struct MMA_64x48x16_F32F16F16_RS { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -2633,6 +2683,7 @@ struct SM90_64x48x16_F32F16F16_RS GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -2655,7 +2706,7 @@ struct SM90_64x48x16_F32F16F16_RS "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x48x16_F32F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x48x16_F32F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -2670,7 +2721,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x64x16_F32F16F16_SS +struct MMA_64x64x16_F32F16F16_SS { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -2691,6 +2742,7 @@ struct SM90_64x64x16_F32F16F16_SS GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -2716,7 +2768,7 @@ struct SM90_64x64x16_F32F16F16_SS "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x64x16_F32F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x64x16_F32F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -2730,7 +2782,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x64x16_F32F16F16_RS +struct MMA_64x64x16_F32F16F16_RS { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -2754,6 +2806,7 @@ struct SM90_64x64x16_F32F16F16_RS GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -2779,7 +2832,7 @@ struct SM90_64x64x16_F32F16F16_RS "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x64x16_F32F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x64x16_F32F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -2794,7 +2847,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x80x16_F32F16F16_SS +struct MMA_64x80x16_F32F16F16_SS { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -2817,6 +2870,7 @@ struct SM90_64x80x16_F32F16F16_SS GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -2845,7 +2899,7 @@ struct SM90_64x80x16_F32F16F16_SS "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x80x16_F32F16F16_SS 
without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x80x16_F32F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -2861,7 +2915,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x80x16_F32F16F16_RS +struct MMA_64x80x16_F32F16F16_RS { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -2887,6 +2941,7 @@ struct SM90_64x80x16_F32F16F16_RS GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -2915,7 +2970,7 @@ struct SM90_64x80x16_F32F16F16_RS "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x80x16_F32F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x80x16_F32F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -2930,7 +2985,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x96x16_F32F16F16_SS +struct MMA_64x96x16_F32F16F16_SS { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -2955,6 +3010,7 @@ struct SM90_64x96x16_F32F16F16_SS GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -2986,7 +3042,7 @@ struct SM90_64x96x16_F32F16F16_SS "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x96x16_F32F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x96x16_F32F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -3000,7 +3056,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x96x16_F32F16F16_RS +struct MMA_64x96x16_F32F16F16_RS { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -3028,6 +3084,7 @@ struct SM90_64x96x16_F32F16F16_RS GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -3059,7 +3116,7 @@ struct SM90_64x96x16_F32F16F16_RS "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x96x16_F32F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x96x16_F32F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -3074,7 +3131,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x112x16_F32F16F16_SS +struct MMA_64x112x16_F32F16F16_SS { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -3101,6 +3158,7 @@ struct SM90_64x112x16_F32F16F16_SS GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -3135,7 +3193,7 @@ struct SM90_64x112x16_F32F16F16_SS "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), 
"n"(int32_t(tnspA)), "n"(int32_t(tnspB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x112x16_F32F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x112x16_F32F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -3151,7 +3209,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x112x16_F32F16F16_RS +struct MMA_64x112x16_F32F16F16_RS { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -3181,6 +3239,7 @@ struct SM90_64x112x16_F32F16F16_RS GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -3215,7 +3274,7 @@ struct SM90_64x112x16_F32F16F16_RS "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x112x16_F32F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x112x16_F32F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -3230,7 +3289,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x128x16_F32F16F16_SS +struct MMA_64x128x16_F32F16F16_SS { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -3259,6 +3318,7 @@ struct SM90_64x128x16_F32F16F16_SS GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -3296,7 +3356,7 @@ struct SM90_64x128x16_F32F16F16_SS "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x128x16_F32F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x128x16_F32F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -3310,7 +3370,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x128x16_F32F16F16_RS +struct MMA_64x128x16_F32F16F16_RS { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -3342,6 +3402,7 @@ struct SM90_64x128x16_F32F16F16_RS GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -3379,7 +3440,7 @@ struct SM90_64x128x16_F32F16F16_RS "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x128x16_F32F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x128x16_F32F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -3394,7 +3455,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x144x16_F32F16F16_SS +struct MMA_64x144x16_F32F16F16_SS { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -3425,6 +3486,7 @@ struct SM90_64x144x16_F32F16F16_SS GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" 
@@ -3465,7 +3527,7 @@ struct SM90_64x144x16_F32F16F16_SS "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x144x16_F32F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x144x16_F32F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -3481,7 +3543,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x144x16_F32F16F16_RS +struct MMA_64x144x16_F32F16F16_RS { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -3515,6 +3577,7 @@ struct SM90_64x144x16_F32F16F16_RS GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -3555,7 +3618,7 @@ struct SM90_64x144x16_F32F16F16_RS "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x144x16_F32F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x144x16_F32F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -3571,7 +3634,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x160x16_F32F16F16_SS +struct MMA_64x160x16_F32F16F16_SS { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -3604,6 +3667,7 @@ struct SM90_64x160x16_F32F16F16_SS GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -3647,7 +3711,7 @@ struct SM90_64x160x16_F32F16F16_SS "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x160x16_F32F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x160x16_F32F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -3663,7 +3727,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x160x16_F32F16F16_RS +struct MMA_64x160x16_F32F16F16_RS { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -3699,6 +3763,7 @@ struct SM90_64x160x16_F32F16F16_RS GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -3742,7 +3807,7 @@ struct SM90_64x160x16_F32F16F16_RS "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x160x16_F32F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x160x16_F32F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -3758,7 +3823,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x176x16_F32F16F16_SS +struct MMA_64x176x16_F32F16F16_SS { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -3793,6 +3858,7 @@ struct SM90_64x176x16_F32F16F16_SS GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if 
defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -3839,7 +3905,7 @@ struct SM90_64x176x16_F32F16F16_SS "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x176x16_F32F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x176x16_F32F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -3855,7 +3921,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x176x16_F32F16F16_RS +struct MMA_64x176x16_F32F16F16_RS { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -3893,6 +3959,7 @@ struct SM90_64x176x16_F32F16F16_RS GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -3939,7 +4006,7 @@ struct SM90_64x176x16_F32F16F16_RS "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x176x16_F32F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x176x16_F32F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -3954,7 +4021,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x192x16_F32F16F16_SS +struct MMA_64x192x16_F32F16F16_SS { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -3991,6 +4058,7 @@ struct SM90_64x192x16_F32F16F16_SS GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -4040,7 +4108,7 @@ struct SM90_64x192x16_F32F16F16_SS "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x192x16_F32F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x192x16_F32F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -4054,7 +4122,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x192x16_F32F16F16_RS +struct MMA_64x192x16_F32F16F16_RS { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -4094,6 +4162,7 @@ struct SM90_64x192x16_F32F16F16_RS GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -4143,7 +4212,7 @@ struct SM90_64x192x16_F32F16F16_RS "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x192x16_F32F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x192x16_F32F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -4158,7 +4227,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x208x16_F32F16F16_SS +struct MMA_64x208x16_F32F16F16_SS { using DRegisters = void; using 
ARegisters = uint64_t[1]; @@ -4197,6 +4266,7 @@ struct SM90_64x208x16_F32F16F16_SS GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -4249,7 +4319,7 @@ struct SM90_64x208x16_F32F16F16_SS "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x208x16_F32F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x208x16_F32F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -4265,7 +4335,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x208x16_F32F16F16_RS +struct MMA_64x208x16_F32F16F16_RS { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -4307,6 +4377,7 @@ struct SM90_64x208x16_F32F16F16_RS GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -4359,7 +4430,7 @@ struct SM90_64x208x16_F32F16F16_RS "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x208x16_F32F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x208x16_F32F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -4375,7 +4446,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x224x16_F32F16F16_SS +struct MMA_64x224x16_F32F16F16_SS { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -4416,6 +4487,7 @@ struct SM90_64x224x16_F32F16F16_SS GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -4471,7 +4543,7 @@ struct SM90_64x224x16_F32F16F16_SS "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x224x16_F32F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x224x16_F32F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -4487,7 +4559,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x224x16_F32F16F16_RS +struct MMA_64x224x16_F32F16F16_RS { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -4531,6 +4603,7 @@ struct SM90_64x224x16_F32F16F16_RS GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -4586,7 +4659,7 @@ struct SM90_64x224x16_F32F16F16_RS "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x224x16_F32F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x224x16_F32F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -4602,7 +4675,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, 
GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x240x16_F32F16F16_SS +struct MMA_64x240x16_F32F16F16_SS { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -4645,6 +4718,7 @@ struct SM90_64x240x16_F32F16F16_SS GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -4703,7 +4777,7 @@ struct SM90_64x240x16_F32F16F16_SS "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x240x16_F32F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x240x16_F32F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -4719,7 +4793,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x240x16_F32F16F16_RS +struct MMA_64x240x16_F32F16F16_RS { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -4765,6 +4839,7 @@ struct SM90_64x240x16_F32F16F16_RS GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -4823,7 +4898,7 @@ struct SM90_64x240x16_F32F16F16_RS "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x240x16_F32F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x240x16_F32F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -4838,7 +4913,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x256x16_F32F16F16_SS +struct MMA_64x256x16_F32F16F16_SS { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -4883,6 +4958,7 @@ struct SM90_64x256x16_F32F16F16_SS GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -4944,7 +5020,7 @@ struct SM90_64x256x16_F32F16F16_SS "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x256x16_F32F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x256x16_F32F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -4958,7 +5034,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x256x16_F32F16F16_RS +struct MMA_64x256x16_F32F16F16_RS { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -5006,6 +5082,7 @@ struct SM90_64x256x16_F32F16F16_RS GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -5067,7 +5144,7 @@ struct SM90_64x256x16_F32F16F16_RS "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x256x16_F32F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use 
MMA_64x256x16_F32F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -5081,7 +5158,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x8x16_F32BF16BF16_SS +struct MMA_64x8x16_F32BF16BF16_SS { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -5095,6 +5172,7 @@ struct SM90_64x8x16_F32BF16BF16_SS GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -5110,7 +5188,7 @@ struct SM90_64x8x16_F32BF16BF16_SS "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x8x16_F32BF16BF16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x8x16_F32BF16BF16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -5124,7 +5202,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x8x16_F32BF16BF16_RS +struct MMA_64x8x16_F32BF16BF16_RS { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -5141,6 +5219,7 @@ struct SM90_64x8x16_F32BF16BF16_RS GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -5156,7 +5235,7 @@ struct SM90_64x8x16_F32BF16BF16_RS "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x8x16_F32BF16BF16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x8x16_F32BF16BF16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -5170,7 +5249,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x16x16_F32BF16BF16_SS +struct MMA_64x16x16_F32BF16BF16_SS { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -5185,6 +5264,7 @@ struct SM90_64x16x16_F32BF16BF16_SS GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -5201,7 +5281,7 @@ struct SM90_64x16x16_F32BF16BF16_SS "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x16x16_F32BF16BF16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x16x16_F32BF16BF16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -5215,7 +5295,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x16x16_F32BF16BF16_RS +struct MMA_64x16x16_F32BF16BF16_RS { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -5233,6 +5313,7 @@ struct SM90_64x16x16_F32BF16BF16_RS GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -5249,7 +5330,7 @@ struct SM90_64x16x16_F32BF16BF16_RS "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); #else - 
CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x16x16_F32BF16BF16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x16x16_F32BF16BF16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -5263,7 +5344,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x32x16_F32BF16BF16_SS +struct MMA_64x32x16_F32BF16BF16_SS { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -5280,6 +5361,7 @@ struct SM90_64x32x16_F32BF16BF16_SS GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -5299,7 +5381,7 @@ struct SM90_64x32x16_F32BF16BF16_SS "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x32x16_F32BF16BF16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x32x16_F32BF16BF16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -5313,7 +5395,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x32x16_F32BF16BF16_RS +struct MMA_64x32x16_F32BF16BF16_RS { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -5333,6 +5415,7 @@ struct SM90_64x32x16_F32BF16BF16_RS GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -5352,7 +5435,7 @@ struct SM90_64x32x16_F32BF16BF16_RS "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x32x16_F32BF16BF16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x32x16_F32BF16BF16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -5367,7 +5450,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x48x16_F32BF16BF16_SS +struct MMA_64x48x16_F32BF16BF16_SS { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -5386,6 +5469,7 @@ struct SM90_64x48x16_F32BF16BF16_SS GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -5408,7 +5492,7 @@ struct SM90_64x48x16_F32BF16BF16_SS "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x48x16_F32BF16BF16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x48x16_F32BF16BF16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -5424,7 +5508,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x48x16_F32BF16BF16_RS +struct MMA_64x48x16_F32BF16BF16_RS { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -5446,6 +5530,7 @@ struct SM90_64x48x16_F32BF16BF16_RS GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ 
-5468,7 +5553,7 @@ struct SM90_64x48x16_F32BF16BF16_RS "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x48x16_F32BF16BF16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x48x16_F32BF16BF16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -5483,7 +5568,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x64x16_F32BF16BF16_SS +struct MMA_64x64x16_F32BF16BF16_SS { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -5504,6 +5589,7 @@ struct SM90_64x64x16_F32BF16BF16_SS GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -5529,7 +5615,7 @@ struct SM90_64x64x16_F32BF16BF16_SS "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x64x16_F32BF16BF16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x64x16_F32BF16BF16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -5543,7 +5629,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x64x16_F32BF16BF16_RS +struct MMA_64x64x16_F32BF16BF16_RS { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -5567,6 +5653,7 @@ struct SM90_64x64x16_F32BF16BF16_RS GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -5592,7 +5679,7 @@ struct SM90_64x64x16_F32BF16BF16_RS "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x64x16_F32BF16BF16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x64x16_F32BF16BF16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -5607,7 +5694,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x80x16_F32BF16BF16_SS +struct MMA_64x80x16_F32BF16BF16_SS { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -5630,6 +5717,7 @@ struct SM90_64x80x16_F32BF16BF16_SS GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -5658,7 +5746,7 @@ struct SM90_64x80x16_F32BF16BF16_SS "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x80x16_F32BF16BF16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x80x16_F32BF16BF16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -5674,7 +5762,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x80x16_F32BF16BF16_RS +struct MMA_64x80x16_F32BF16BF16_RS { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -5700,6 +5788,7 @@ struct SM90_64x80x16_F32BF16BF16_RS GMMA::ScaleOut const scale_D = 
GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -5728,7 +5817,7 @@ struct SM90_64x80x16_F32BF16BF16_RS "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x80x16_F32BF16BF16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x80x16_F32BF16BF16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -5743,7 +5832,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x96x16_F32BF16BF16_SS +struct MMA_64x96x16_F32BF16BF16_SS { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -5768,6 +5857,7 @@ struct SM90_64x96x16_F32BF16BF16_SS GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -5799,7 +5889,7 @@ struct SM90_64x96x16_F32BF16BF16_SS "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x96x16_F32BF16BF16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x96x16_F32BF16BF16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -5813,7 +5903,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x96x16_F32BF16BF16_RS +struct MMA_64x96x16_F32BF16BF16_RS { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -5841,6 +5931,7 @@ struct SM90_64x96x16_F32BF16BF16_RS GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -5872,7 +5963,7 @@ struct SM90_64x96x16_F32BF16BF16_RS "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x96x16_F32BF16BF16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x96x16_F32BF16BF16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -5887,7 +5978,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x112x16_F32BF16BF16_SS +struct MMA_64x112x16_F32BF16BF16_SS { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -5914,6 +6005,7 @@ struct SM90_64x112x16_F32BF16BF16_SS GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -5948,7 +6040,7 @@ struct SM90_64x112x16_F32BF16BF16_SS "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x112x16_F32BF16BF16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x112x16_F32BF16BF16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -5964,7 +6056,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x112x16_F32BF16BF16_RS +struct 
MMA_64x112x16_F32BF16BF16_RS { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -5994,6 +6086,7 @@ struct SM90_64x112x16_F32BF16BF16_RS GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -6028,7 +6121,7 @@ struct SM90_64x112x16_F32BF16BF16_RS "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x112x16_F32BF16BF16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x112x16_F32BF16BF16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -6043,7 +6136,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x128x16_F32BF16BF16_SS +struct MMA_64x128x16_F32BF16BF16_SS { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -6072,6 +6165,7 @@ struct SM90_64x128x16_F32BF16BF16_SS GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -6109,7 +6203,7 @@ struct SM90_64x128x16_F32BF16BF16_SS "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x128x16_F32BF16BF16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x128x16_F32BF16BF16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -6123,7 +6217,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x128x16_F32BF16BF16_RS +struct MMA_64x128x16_F32BF16BF16_RS { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -6155,6 +6249,7 @@ struct SM90_64x128x16_F32BF16BF16_RS GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -6192,7 +6287,7 @@ struct SM90_64x128x16_F32BF16BF16_RS "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x128x16_F32BF16BF16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x128x16_F32BF16BF16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -6207,7 +6302,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x144x16_F32BF16BF16_SS +struct MMA_64x144x16_F32BF16BF16_SS { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -6238,6 +6333,7 @@ struct SM90_64x144x16_F32BF16BF16_SS GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -6278,7 +6374,7 @@ struct SM90_64x144x16_F32BF16BF16_SS "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x144x16_F32BF16BF16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x144x16_F32BF16BF16_SS without 
CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -6294,7 +6390,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x144x16_F32BF16BF16_RS +struct MMA_64x144x16_F32BF16BF16_RS { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -6328,6 +6424,7 @@ struct SM90_64x144x16_F32BF16BF16_RS GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -6368,7 +6465,7 @@ struct SM90_64x144x16_F32BF16BF16_RS "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x144x16_F32BF16BF16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x144x16_F32BF16BF16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -6384,7 +6481,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x160x16_F32BF16BF16_SS +struct MMA_64x160x16_F32BF16BF16_SS { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -6417,6 +6514,7 @@ struct SM90_64x160x16_F32BF16BF16_SS GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -6460,7 +6558,7 @@ struct SM90_64x160x16_F32BF16BF16_SS "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x160x16_F32BF16BF16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x160x16_F32BF16BF16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -6476,7 +6574,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x160x16_F32BF16BF16_RS +struct MMA_64x160x16_F32BF16BF16_RS { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -6512,6 +6610,7 @@ struct SM90_64x160x16_F32BF16BF16_RS GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -6555,7 +6654,7 @@ struct SM90_64x160x16_F32BF16BF16_RS "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x160x16_F32BF16BF16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x160x16_F32BF16BF16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -6571,7 +6670,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x176x16_F32BF16BF16_SS +struct MMA_64x176x16_F32BF16BF16_SS { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -6606,6 +6705,7 @@ struct SM90_64x176x16_F32BF16BF16_SS GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -6652,7 +6752,7 @@ struct SM90_64x176x16_F32BF16BF16_SS "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); #else - 
CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x176x16_F32BF16BF16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x176x16_F32BF16BF16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -6668,7 +6768,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x176x16_F32BF16BF16_RS +struct MMA_64x176x16_F32BF16BF16_RS { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -6706,6 +6806,7 @@ struct SM90_64x176x16_F32BF16BF16_RS GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -6752,7 +6853,7 @@ struct SM90_64x176x16_F32BF16BF16_RS "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x176x16_F32BF16BF16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x176x16_F32BF16BF16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -6767,7 +6868,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x192x16_F32BF16BF16_SS +struct MMA_64x192x16_F32BF16BF16_SS { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -6804,6 +6905,7 @@ struct SM90_64x192x16_F32BF16BF16_SS GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -6853,7 +6955,7 @@ struct SM90_64x192x16_F32BF16BF16_SS "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x192x16_F32BF16BF16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x192x16_F32BF16BF16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -6867,7 +6969,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x192x16_F32BF16BF16_RS +struct MMA_64x192x16_F32BF16BF16_RS { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -6907,6 +7009,7 @@ struct SM90_64x192x16_F32BF16BF16_RS GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -6956,7 +7059,7 @@ struct SM90_64x192x16_F32BF16BF16_RS "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x192x16_F32BF16BF16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x192x16_F32BF16BF16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -6971,7 +7074,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x208x16_F32BF16BF16_SS +struct MMA_64x208x16_F32BF16BF16_SS { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -7010,6 +7113,7 @@ struct SM90_64x208x16_F32BF16BF16_SS GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ 
-7062,7 +7166,7 @@ struct SM90_64x208x16_F32BF16BF16_SS "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x208x16_F32BF16BF16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x208x16_F32BF16BF16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -7078,7 +7182,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x208x16_F32BF16BF16_RS +struct MMA_64x208x16_F32BF16BF16_RS { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -7120,6 +7224,7 @@ struct SM90_64x208x16_F32BF16BF16_RS GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -7172,7 +7277,7 @@ struct SM90_64x208x16_F32BF16BF16_RS "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x208x16_F32BF16BF16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x208x16_F32BF16BF16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -7188,7 +7293,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x224x16_F32BF16BF16_SS +struct MMA_64x224x16_F32BF16BF16_SS { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -7229,6 +7334,7 @@ struct SM90_64x224x16_F32BF16BF16_SS GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -7284,7 +7390,7 @@ struct SM90_64x224x16_F32BF16BF16_SS "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x224x16_F32BF16BF16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x224x16_F32BF16BF16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -7300,7 +7406,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x224x16_F32BF16BF16_RS +struct MMA_64x224x16_F32BF16BF16_RS { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -7344,6 +7450,7 @@ struct SM90_64x224x16_F32BF16BF16_RS GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -7399,7 +7506,7 @@ struct SM90_64x224x16_F32BF16BF16_RS "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x224x16_F32BF16BF16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x224x16_F32BF16BF16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -7415,7 +7522,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x240x16_F32BF16BF16_SS +struct MMA_64x240x16_F32BF16BF16_SS { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -7458,6 +7565,7 @@ struct SM90_64x240x16_F32BF16BF16_SS GMMA::ScaleOut const 
scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -7516,7 +7624,7 @@ struct SM90_64x240x16_F32BF16BF16_SS "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x240x16_F32BF16BF16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x240x16_F32BF16BF16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -7532,7 +7640,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x240x16_F32BF16BF16_RS +struct MMA_64x240x16_F32BF16BF16_RS { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -7578,6 +7686,7 @@ struct SM90_64x240x16_F32BF16BF16_RS GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -7636,7 +7745,7 @@ struct SM90_64x240x16_F32BF16BF16_RS "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x240x16_F32BF16BF16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x240x16_F32BF16BF16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -7651,7 +7760,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x256x16_F32BF16BF16_SS +struct MMA_64x256x16_F32BF16BF16_SS { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -7696,6 +7805,7 @@ struct SM90_64x256x16_F32BF16BF16_SS GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -7757,7 +7867,7 @@ struct SM90_64x256x16_F32BF16BF16_SS "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x256x16_F32BF16BF16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x256x16_F32BF16BF16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -7771,7 +7881,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x256x16_F32BF16BF16_RS +struct MMA_64x256x16_F32BF16BF16_RS { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -7819,6 +7929,7 @@ struct SM90_64x256x16_F32BF16BF16_RS GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -7880,7 +7991,7 @@ struct SM90_64x256x16_F32BF16BF16_RS "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x256x16_F32BF16BF16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x256x16_F32BF16BF16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -7892,7 +8003,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct 
SM90_64x8x8_F32TF32TF32_SS_TN +struct MMA_64x8x8_F32TF32TF32_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -7906,6 +8017,7 @@ struct SM90_64x8x8_F32TF32TF32_SS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -7921,7 +8033,7 @@ struct SM90_64x8x8_F32TF32TF32_SS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x8x8_F32TF32TF32_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x8x8_F32TF32TF32_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -7933,7 +8045,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x8x8_F32TF32TF32_RS_TN +struct MMA_64x8x8_F32TF32TF32_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -7947,6 +8059,7 @@ struct SM90_64x8x8_F32TF32TF32_RS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -7962,7 +8075,7 @@ struct SM90_64x8x8_F32TF32TF32_RS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x8x8_F32TF32TF32_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x8x8_F32TF32TF32_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -7974,7 +8087,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x16x8_F32TF32TF32_SS_TN +struct MMA_64x16x8_F32TF32TF32_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -7989,6 +8102,7 @@ struct SM90_64x16x8_F32TF32TF32_SS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -8005,7 +8119,7 @@ struct SM90_64x16x8_F32TF32TF32_SS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x16x8_F32TF32TF32_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x16x8_F32TF32TF32_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -8017,7 +8131,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x16x8_F32TF32TF32_RS_TN +struct MMA_64x16x8_F32TF32TF32_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -8032,6 +8146,7 @@ struct SM90_64x16x8_F32TF32TF32_RS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -8048,7 +8163,7 @@ struct SM90_64x16x8_F32TF32TF32_RS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x16x8_F32TF32TF32_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x16x8_F32TF32TF32_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -8060,7 +8175,7 @@ template < GMMA::ScaleIn scaleA = 
GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x32x8_F32TF32TF32_SS_TN +struct MMA_64x32x8_F32TF32TF32_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -8077,6 +8192,7 @@ struct SM90_64x32x8_F32TF32TF32_SS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -8096,7 +8212,7 @@ struct SM90_64x32x8_F32TF32TF32_SS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x32x8_F32TF32TF32_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x32x8_F32TF32TF32_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -8108,7 +8224,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x32x8_F32TF32TF32_RS_TN +struct MMA_64x32x8_F32TF32TF32_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -8125,6 +8241,7 @@ struct SM90_64x32x8_F32TF32TF32_RS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -8144,7 +8261,7 @@ struct SM90_64x32x8_F32TF32TF32_RS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x32x8_F32TF32TF32_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x32x8_F32TF32TF32_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -8157,7 +8274,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x48x8_F32TF32TF32_SS_TN +struct MMA_64x48x8_F32TF32TF32_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -8176,6 +8293,7 @@ struct SM90_64x48x8_F32TF32TF32_SS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -8198,7 +8316,7 @@ struct SM90_64x48x8_F32TF32TF32_SS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x48x8_F32TF32TF32_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x48x8_F32TF32TF32_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -8212,7 +8330,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x48x8_F32TF32TF32_RS_TN +struct MMA_64x48x8_F32TF32TF32_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -8231,6 +8349,7 @@ struct SM90_64x48x8_F32TF32TF32_RS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -8253,7 +8372,7 @@ struct SM90_64x48x8_F32TF32TF32_RS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x48x8_F32TF32TF32_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x48x8_F32TF32TF32_RS_TN without 
CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -8266,7 +8385,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x64x8_F32TF32TF32_SS_TN +struct MMA_64x64x8_F32TF32TF32_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -8287,6 +8406,7 @@ struct SM90_64x64x8_F32TF32TF32_SS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -8312,7 +8432,7 @@ struct SM90_64x64x8_F32TF32TF32_SS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x64x8_F32TF32TF32_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x64x8_F32TF32TF32_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -8324,7 +8444,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x64x8_F32TF32TF32_RS_TN +struct MMA_64x64x8_F32TF32TF32_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -8345,6 +8465,7 @@ struct SM90_64x64x8_F32TF32TF32_RS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -8370,7 +8491,7 @@ struct SM90_64x64x8_F32TF32TF32_RS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x64x8_F32TF32TF32_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x64x8_F32TF32TF32_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -8383,7 +8504,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x80x8_F32TF32TF32_SS_TN +struct MMA_64x80x8_F32TF32TF32_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -8406,6 +8527,7 @@ struct SM90_64x80x8_F32TF32TF32_SS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -8434,7 +8556,7 @@ struct SM90_64x80x8_F32TF32TF32_SS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x80x8_F32TF32TF32_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x80x8_F32TF32TF32_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -8448,7 +8570,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x80x8_F32TF32TF32_RS_TN +struct MMA_64x80x8_F32TF32TF32_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -8471,6 +8593,7 @@ struct SM90_64x80x8_F32TF32TF32_RS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -8499,7 +8622,7 @@ struct SM90_64x80x8_F32TF32TF32_RS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x80x8_F32TF32TF32_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + 
CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x80x8_F32TF32TF32_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -8512,7 +8635,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x96x8_F32TF32TF32_SS_TN +struct MMA_64x96x8_F32TF32TF32_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -8537,6 +8660,7 @@ struct SM90_64x96x8_F32TF32TF32_SS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -8568,7 +8692,7 @@ struct SM90_64x96x8_F32TF32TF32_SS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x96x8_F32TF32TF32_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x96x8_F32TF32TF32_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -8580,7 +8704,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x96x8_F32TF32TF32_RS_TN +struct MMA_64x96x8_F32TF32TF32_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -8605,6 +8729,7 @@ struct SM90_64x96x8_F32TF32TF32_RS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -8636,7 +8761,7 @@ struct SM90_64x96x8_F32TF32TF32_RS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x96x8_F32TF32TF32_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x96x8_F32TF32TF32_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -8649,7 +8774,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x112x8_F32TF32TF32_SS_TN +struct MMA_64x112x8_F32TF32TF32_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -8676,6 +8801,7 @@ struct SM90_64x112x8_F32TF32TF32_SS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -8710,7 +8836,7 @@ struct SM90_64x112x8_F32TF32TF32_SS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x112x8_F32TF32TF32_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x112x8_F32TF32TF32_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -8724,7 +8850,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x112x8_F32TF32TF32_RS_TN +struct MMA_64x112x8_F32TF32TF32_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -8751,6 +8877,7 @@ struct SM90_64x112x8_F32TF32TF32_RS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -8785,7 +8912,7 @@ struct SM90_64x112x8_F32TF32TF32_RS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - 
CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x112x8_F32TF32TF32_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x112x8_F32TF32TF32_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -8798,7 +8925,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x128x8_F32TF32TF32_SS_TN +struct MMA_64x128x8_F32TF32TF32_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -8827,6 +8954,7 @@ struct SM90_64x128x8_F32TF32TF32_SS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -8864,7 +8992,7 @@ struct SM90_64x128x8_F32TF32TF32_SS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x128x8_F32TF32TF32_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x128x8_F32TF32TF32_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -8876,7 +9004,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x128x8_F32TF32TF32_RS_TN +struct MMA_64x128x8_F32TF32TF32_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -8905,6 +9033,7 @@ struct SM90_64x128x8_F32TF32TF32_RS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -8942,7 +9071,7 @@ struct SM90_64x128x8_F32TF32TF32_RS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x128x8_F32TF32TF32_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x128x8_F32TF32TF32_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -8955,7 +9084,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x144x8_F32TF32TF32_SS_TN +struct MMA_64x144x8_F32TF32TF32_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -8986,6 +9115,7 @@ struct SM90_64x144x8_F32TF32TF32_SS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -9026,7 +9156,7 @@ struct SM90_64x144x8_F32TF32TF32_SS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x144x8_F32TF32TF32_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x144x8_F32TF32TF32_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -9040,7 +9170,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x144x8_F32TF32TF32_RS_TN +struct MMA_64x144x8_F32TF32TF32_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -9071,6 +9201,7 @@ struct SM90_64x144x8_F32TF32TF32_RS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -9111,7 +9242,7 @@ struct 
SM90_64x144x8_F32TF32TF32_RS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x144x8_F32TF32TF32_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x144x8_F32TF32TF32_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -9125,7 +9256,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x160x8_F32TF32TF32_SS_TN +struct MMA_64x160x8_F32TF32TF32_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -9158,6 +9289,7 @@ struct SM90_64x160x8_F32TF32TF32_SS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -9201,7 +9333,7 @@ struct SM90_64x160x8_F32TF32TF32_SS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x160x8_F32TF32TF32_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x160x8_F32TF32TF32_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -9215,7 +9347,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x160x8_F32TF32TF32_RS_TN +struct MMA_64x160x8_F32TF32TF32_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -9248,6 +9380,7 @@ struct SM90_64x160x8_F32TF32TF32_RS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -9291,7 +9424,7 @@ struct SM90_64x160x8_F32TF32TF32_RS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x160x8_F32TF32TF32_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x160x8_F32TF32TF32_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -9305,7 +9438,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x176x8_F32TF32TF32_SS_TN +struct MMA_64x176x8_F32TF32TF32_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -9340,6 +9473,7 @@ struct SM90_64x176x8_F32TF32TF32_SS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -9386,7 +9520,7 @@ struct SM90_64x176x8_F32TF32TF32_SS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x176x8_F32TF32TF32_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x176x8_F32TF32TF32_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -9400,7 +9534,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x176x8_F32TF32TF32_RS_TN +struct MMA_64x176x8_F32TF32TF32_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -9435,6 +9569,7 @@ struct SM90_64x176x8_F32TF32TF32_RS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + 
cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -9481,7 +9616,7 @@ struct SM90_64x176x8_F32TF32TF32_RS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x176x8_F32TF32TF32_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x176x8_F32TF32TF32_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -9494,7 +9629,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x192x8_F32TF32TF32_SS_TN +struct MMA_64x192x8_F32TF32TF32_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -9531,6 +9666,7 @@ struct SM90_64x192x8_F32TF32TF32_SS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -9580,7 +9716,7 @@ struct SM90_64x192x8_F32TF32TF32_SS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x192x8_F32TF32TF32_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x192x8_F32TF32TF32_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -9592,7 +9728,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x192x8_F32TF32TF32_RS_TN +struct MMA_64x192x8_F32TF32TF32_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -9629,6 +9765,7 @@ struct SM90_64x192x8_F32TF32TF32_RS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -9678,7 +9815,7 @@ struct SM90_64x192x8_F32TF32TF32_RS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x192x8_F32TF32TF32_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x192x8_F32TF32TF32_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -9691,7 +9828,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x208x8_F32TF32TF32_SS_TN +struct MMA_64x208x8_F32TF32TF32_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -9730,6 +9867,7 @@ struct SM90_64x208x8_F32TF32TF32_SS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -9782,7 +9920,7 @@ struct SM90_64x208x8_F32TF32TF32_SS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x208x8_F32TF32TF32_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x208x8_F32TF32TF32_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -9796,7 +9934,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x208x8_F32TF32TF32_RS_TN +struct MMA_64x208x8_F32TF32TF32_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -9835,6 +9973,7 @@ struct 
SM90_64x208x8_F32TF32TF32_RS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -9887,7 +10026,7 @@ struct SM90_64x208x8_F32TF32TF32_RS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x208x8_F32TF32TF32_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x208x8_F32TF32TF32_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -9901,7 +10040,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x224x8_F32TF32TF32_SS_TN +struct MMA_64x224x8_F32TF32TF32_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -9942,6 +10081,7 @@ struct SM90_64x224x8_F32TF32TF32_SS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -9997,7 +10137,7 @@ struct SM90_64x224x8_F32TF32TF32_SS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x224x8_F32TF32TF32_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x224x8_F32TF32TF32_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -10011,7 +10151,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x224x8_F32TF32TF32_RS_TN +struct MMA_64x224x8_F32TF32TF32_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -10052,6 +10192,7 @@ struct SM90_64x224x8_F32TF32TF32_RS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -10107,7 +10248,7 @@ struct SM90_64x224x8_F32TF32TF32_RS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x224x8_F32TF32TF32_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x224x8_F32TF32TF32_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -10121,7 +10262,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x240x8_F32TF32TF32_SS_TN +struct MMA_64x240x8_F32TF32TF32_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -10164,6 +10305,7 @@ struct SM90_64x240x8_F32TF32TF32_SS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -10222,7 +10364,7 @@ struct SM90_64x240x8_F32TF32TF32_SS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x240x8_F32TF32TF32_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x240x8_F32TF32TF32_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -10236,7 +10378,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x240x8_F32TF32TF32_RS_TN 
+struct MMA_64x240x8_F32TF32TF32_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -10279,6 +10421,7 @@ struct SM90_64x240x8_F32TF32TF32_RS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -10337,7 +10480,7 @@ struct SM90_64x240x8_F32TF32TF32_RS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x240x8_F32TF32TF32_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x240x8_F32TF32TF32_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -10350,7 +10493,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x256x8_F32TF32TF32_SS_TN +struct MMA_64x256x8_F32TF32TF32_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -10395,6 +10538,7 @@ struct SM90_64x256x8_F32TF32TF32_SS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -10456,7 +10600,7 @@ struct SM90_64x256x8_F32TF32TF32_SS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x256x8_F32TF32TF32_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x256x8_F32TF32TF32_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -10468,7 +10612,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x256x8_F32TF32TF32_RS_TN +struct MMA_64x256x8_F32TF32TF32_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -10513,6 +10657,7 @@ struct SM90_64x256x8_F32TF32TF32_RS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -10574,7 +10719,7 @@ struct SM90_64x256x8_F32TF32TF32_RS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x256x8_F32TF32TF32_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x256x8_F32TF32TF32_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -10582,7 +10727,7 @@ struct SM90_64x256x8_F32TF32TF32_RS_TN //////////////////////////////////////////////////////////////////////////////////////////////////// // GMMA 64x8x32 TN S32+=S8*S8 -struct SM90_64x8x32_S32S8S8_SS_TN +struct MMA_64x8x32_S32S8S8_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -10596,6 +10741,7 @@ struct SM90_64x8x32_S32S8S8_SS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -10611,7 +10757,7 @@ struct SM90_64x8x32_S32S8S8_SS_TN "l"(desc_b), "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x8x32_S32S8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x8x32_S32S8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -10619,7 +10765,7 @@ struct 
SM90_64x8x32_S32S8S8_SS_TN //////////////////////////////////////////////////////////////////////////////////////////////////// // GMMA 64x8x32 TN S32+=S8*S8 -struct SM90_64x8x32_S32S8S8_SS_TN_SATURATE +struct MMA_64x8x32_S32S8S8_SS_TN_SATURATE { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -10633,6 +10779,7 @@ struct SM90_64x8x32_S32S8S8_SS_TN_SATURATE GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -10648,7 +10795,7 @@ struct SM90_64x8x32_S32S8S8_SS_TN_SATURATE "l"(desc_b), "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x8x32_S32S8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x8x32_S32S8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -10656,7 +10803,7 @@ struct SM90_64x8x32_S32S8S8_SS_TN_SATURATE //////////////////////////////////////////////////////////////////////////////////////////////////// // GMMA 64x16x32 TN S32+=S8*S8 -struct SM90_64x16x32_S32S8S8_SS_TN +struct MMA_64x16x32_S32S8S8_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -10671,6 +10818,7 @@ struct SM90_64x16x32_S32S8S8_SS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -10687,7 +10835,7 @@ struct SM90_64x16x32_S32S8S8_SS_TN "l"(desc_b), "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x16x32_S32S8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x16x32_S32S8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -10695,7 +10843,7 @@ struct SM90_64x16x32_S32S8S8_SS_TN //////////////////////////////////////////////////////////////////////////////////////////////////// // GMMA 64x16x32 TN S32+=S8*S8 -struct SM90_64x16x32_S32S8S8_SS_TN_SATURATE +struct MMA_64x16x32_S32S8S8_SS_TN_SATURATE { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -10710,6 +10858,7 @@ struct SM90_64x16x32_S32S8S8_SS_TN_SATURATE GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -10726,7 +10875,7 @@ struct SM90_64x16x32_S32S8S8_SS_TN_SATURATE "l"(desc_b), "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x16x32_S32S8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x16x32_S32S8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -10734,7 +10883,7 @@ struct SM90_64x16x32_S32S8S8_SS_TN_SATURATE //////////////////////////////////////////////////////////////////////////////////////////////////// // GMMA 64x32x32 TN S32+=S8*S8 -struct SM90_64x32x32_S32S8S8_SS_TN +struct MMA_64x32x32_S32S8S8_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -10751,6 +10900,7 @@ struct SM90_64x32x32_S32S8S8_SS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -10770,7 +10920,7 @@ struct SM90_64x32x32_S32S8S8_SS_TN "l"(desc_b), "r"(int32_t(scale_D))); #else 
- CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x32x32_S32S8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x32x32_S32S8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -10778,7 +10928,7 @@ struct SM90_64x32x32_S32S8S8_SS_TN //////////////////////////////////////////////////////////////////////////////////////////////////// // GMMA 64x32x32 TN S32+=S8*S8 -struct SM90_64x32x32_S32S8S8_SS_TN_SATURATE +struct MMA_64x32x32_S32S8S8_SS_TN_SATURATE { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -10795,6 +10945,7 @@ struct SM90_64x32x32_S32S8S8_SS_TN_SATURATE GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -10814,7 +10965,7 @@ struct SM90_64x32x32_S32S8S8_SS_TN_SATURATE "l"(desc_b), "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x32x32_S32S8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x32x32_S32S8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -10823,7 +10974,7 @@ struct SM90_64x32x32_S32S8S8_SS_TN_SATURATE #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) // GMMA 64x48x32 TN S32+=S8*S8 -struct SM90_64x48x32_S32S8S8_SS_TN +struct MMA_64x48x32_S32S8S8_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -10842,6 +10993,7 @@ struct SM90_64x48x32_S32S8S8_SS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -10864,7 +11016,7 @@ struct SM90_64x48x32_S32S8S8_SS_TN "l"(desc_b), "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x48x32_S32S8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x48x32_S32S8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -10874,7 +11026,7 @@ struct SM90_64x48x32_S32S8S8_SS_TN #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) // GMMA 64x48x32 TN S32+=S8*S8 -struct SM90_64x48x32_S32S8S8_SS_TN_SATURATE +struct MMA_64x48x32_S32S8S8_SS_TN_SATURATE { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -10893,6 +11045,7 @@ struct SM90_64x48x32_S32S8S8_SS_TN_SATURATE GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -10915,7 +11068,7 @@ struct SM90_64x48x32_S32S8S8_SS_TN_SATURATE "l"(desc_b), "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x48x32_S32S8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x48x32_S32S8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -10924,7 +11077,7 @@ struct SM90_64x48x32_S32S8S8_SS_TN_SATURATE //////////////////////////////////////////////////////////////////////////////////////////////////// // GMMA 64x64x32 TN S32+=S8*S8 -struct SM90_64x64x32_S32S8S8_SS_TN +struct MMA_64x64x32_S32S8S8_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -10945,6 +11098,7 @@ struct SM90_64x64x32_S32S8S8_SS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + 
cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -10970,7 +11124,7 @@ struct SM90_64x64x32_S32S8S8_SS_TN "l"(desc_b), "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x64x32_S32S8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x64x32_S32S8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -10978,7 +11132,7 @@ struct SM90_64x64x32_S32S8S8_SS_TN //////////////////////////////////////////////////////////////////////////////////////////////////// // GMMA 64x64x32 TN S32+=S8*S8 -struct SM90_64x64x32_S32S8S8_SS_TN_SATURATE +struct MMA_64x64x32_S32S8S8_SS_TN_SATURATE { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -10999,6 +11153,7 @@ struct SM90_64x64x32_S32S8S8_SS_TN_SATURATE GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -11024,7 +11179,7 @@ struct SM90_64x64x32_S32S8S8_SS_TN_SATURATE "l"(desc_b), "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x64x32_S32S8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x64x32_S32S8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -11033,7 +11188,7 @@ struct SM90_64x64x32_S32S8S8_SS_TN_SATURATE #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) // GMMA 64x80x32 TN S32+=S8*S8 -struct SM90_64x80x32_S32S8S8_SS_TN +struct MMA_64x80x32_S32S8S8_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -11056,6 +11211,7 @@ struct SM90_64x80x32_S32S8S8_SS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -11084,7 +11240,7 @@ struct SM90_64x80x32_S32S8S8_SS_TN "l"(desc_b), "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x80x32_S32S8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x80x32_S32S8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -11094,7 +11250,7 @@ struct SM90_64x80x32_S32S8S8_SS_TN #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) // GMMA 64x80x32 TN S32+=S8*S8 -struct SM90_64x80x32_S32S8S8_SS_TN_SATURATE +struct MMA_64x80x32_S32S8S8_SS_TN_SATURATE { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -11117,6 +11273,7 @@ struct SM90_64x80x32_S32S8S8_SS_TN_SATURATE GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -11145,7 +11302,7 @@ struct SM90_64x80x32_S32S8S8_SS_TN_SATURATE "l"(desc_b), "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x80x32_S32S8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x80x32_S32S8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -11154,7 +11311,7 @@ struct SM90_64x80x32_S32S8S8_SS_TN_SATURATE //////////////////////////////////////////////////////////////////////////////////////////////////// // GMMA 64x96x32 TN S32+=S8*S8 -struct SM90_64x96x32_S32S8S8_SS_TN +struct MMA_64x96x32_S32S8S8_SS_TN { using DRegisters = void; using 
ARegisters = uint64_t[1]; @@ -11179,6 +11336,7 @@ struct SM90_64x96x32_S32S8S8_SS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -11210,7 +11368,7 @@ struct SM90_64x96x32_S32S8S8_SS_TN "l"(desc_b), "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x96x32_S32S8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x96x32_S32S8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -11218,7 +11376,7 @@ struct SM90_64x96x32_S32S8S8_SS_TN //////////////////////////////////////////////////////////////////////////////////////////////////// // GMMA 64x96x32 TN S32+=S8*S8 -struct SM90_64x96x32_S32S8S8_SS_TN_SATURATE +struct MMA_64x96x32_S32S8S8_SS_TN_SATURATE { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -11243,6 +11401,7 @@ struct SM90_64x96x32_S32S8S8_SS_TN_SATURATE GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -11274,7 +11433,7 @@ struct SM90_64x96x32_S32S8S8_SS_TN_SATURATE "l"(desc_b), "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x96x32_S32S8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x96x32_S32S8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -11283,7 +11442,7 @@ struct SM90_64x96x32_S32S8S8_SS_TN_SATURATE #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) // GMMA 64x112x32 TN S32+=S8*S8 -struct SM90_64x112x32_S32S8S8_SS_TN +struct MMA_64x112x32_S32S8S8_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -11310,6 +11469,7 @@ struct SM90_64x112x32_S32S8S8_SS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -11344,7 +11504,7 @@ struct SM90_64x112x32_S32S8S8_SS_TN "l"(desc_b), "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x112x32_S32S8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x112x32_S32S8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -11354,7 +11514,7 @@ struct SM90_64x112x32_S32S8S8_SS_TN #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) // GMMA 64x112x32 TN S32+=S8*S8 -struct SM90_64x112x32_S32S8S8_SS_TN_SATURATE +struct MMA_64x112x32_S32S8S8_SS_TN_SATURATE { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -11381,6 +11541,7 @@ struct SM90_64x112x32_S32S8S8_SS_TN_SATURATE GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -11415,7 +11576,7 @@ struct SM90_64x112x32_S32S8S8_SS_TN_SATURATE "l"(desc_b), "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x112x32_S32S8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x112x32_S32S8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -11424,7 +11585,7 @@ struct SM90_64x112x32_S32S8S8_SS_TN_SATURATE 
//////////////////////////////////////////////////////////////////////////////////////////////////// // GMMA 64x128x32 TN S32+=S8*S8 -struct SM90_64x128x32_S32S8S8_SS_TN +struct MMA_64x128x32_S32S8S8_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -11453,6 +11614,7 @@ struct SM90_64x128x32_S32S8S8_SS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -11490,7 +11652,7 @@ struct SM90_64x128x32_S32S8S8_SS_TN "l"(desc_b), "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x128x32_S32S8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x128x32_S32S8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -11498,7 +11660,7 @@ struct SM90_64x128x32_S32S8S8_SS_TN //////////////////////////////////////////////////////////////////////////////////////////////////// // GMMA 64x128x32 TN S32+=S8*S8 -struct SM90_64x128x32_S32S8S8_SS_TN_SATURATE +struct MMA_64x128x32_S32S8S8_SS_TN_SATURATE { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -11527,6 +11689,7 @@ struct SM90_64x128x32_S32S8S8_SS_TN_SATURATE GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -11564,7 +11727,7 @@ struct SM90_64x128x32_S32S8S8_SS_TN_SATURATE "l"(desc_b), "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x128x32_S32S8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x128x32_S32S8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -11573,7 +11736,7 @@ struct SM90_64x128x32_S32S8S8_SS_TN_SATURATE #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) // GMMA 64x144x32 TN S32+=S8*S8 -struct SM90_64x144x32_S32S8S8_SS_TN +struct MMA_64x144x32_S32S8S8_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -11604,6 +11767,7 @@ struct SM90_64x144x32_S32S8S8_SS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -11644,7 +11808,7 @@ struct SM90_64x144x32_S32S8S8_SS_TN "l"(desc_b), "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x144x32_S32S8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x144x32_S32S8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -11654,7 +11818,7 @@ struct SM90_64x144x32_S32S8S8_SS_TN #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) // GMMA 64x144x32 TN S32+=S8*S8 -struct SM90_64x144x32_S32S8S8_SS_TN_SATURATE +struct MMA_64x144x32_S32S8S8_SS_TN_SATURATE { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -11685,6 +11849,7 @@ struct SM90_64x144x32_S32S8S8_SS_TN_SATURATE GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -11725,7 +11890,7 @@ struct SM90_64x144x32_S32S8S8_SS_TN_SATURATE "l"(desc_b), "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x144x32_S32S8S8_SS_TN_SATURATE without 
CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x144x32_S32S8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -11735,7 +11900,7 @@ struct SM90_64x144x32_S32S8S8_SS_TN_SATURATE #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) // GMMA 64x160x32 TN S32+=S8*S8 -struct SM90_64x160x32_S32S8S8_SS_TN +struct MMA_64x160x32_S32S8S8_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -11768,6 +11933,7 @@ struct SM90_64x160x32_S32S8S8_SS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -11811,7 +11977,7 @@ struct SM90_64x160x32_S32S8S8_SS_TN "l"(desc_b), "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x160x32_S32S8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x160x32_S32S8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -11821,7 +11987,7 @@ struct SM90_64x160x32_S32S8S8_SS_TN #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) // GMMA 64x160x32 TN S32+=S8*S8 -struct SM90_64x160x32_S32S8S8_SS_TN_SATURATE +struct MMA_64x160x32_S32S8S8_SS_TN_SATURATE { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -11854,6 +12020,7 @@ struct SM90_64x160x32_S32S8S8_SS_TN_SATURATE GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -11897,7 +12064,7 @@ struct SM90_64x160x32_S32S8S8_SS_TN_SATURATE "l"(desc_b), "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x160x32_S32S8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x160x32_S32S8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -11907,7 +12074,7 @@ struct SM90_64x160x32_S32S8S8_SS_TN_SATURATE #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) // GMMA 64x176x32 TN S32+=S8*S8 -struct SM90_64x176x32_S32S8S8_SS_TN +struct MMA_64x176x32_S32S8S8_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -11942,6 +12109,7 @@ struct SM90_64x176x32_S32S8S8_SS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -11988,7 +12156,7 @@ struct SM90_64x176x32_S32S8S8_SS_TN "l"(desc_b), "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x176x32_S32S8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x176x32_S32S8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -11998,7 +12166,7 @@ struct SM90_64x176x32_S32S8S8_SS_TN #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) // GMMA 64x176x32 TN S32+=S8*S8 -struct SM90_64x176x32_S32S8S8_SS_TN_SATURATE +struct MMA_64x176x32_S32S8S8_SS_TN_SATURATE { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -12033,6 +12201,7 @@ struct SM90_64x176x32_S32S8S8_SS_TN_SATURATE GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -12079,7 +12248,7 @@ struct SM90_64x176x32_S32S8S8_SS_TN_SATURATE "l"(desc_b), 
"r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x176x32_S32S8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x176x32_S32S8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -12088,7 +12257,7 @@ struct SM90_64x176x32_S32S8S8_SS_TN_SATURATE //////////////////////////////////////////////////////////////////////////////////////////////////// // GMMA 64x192x32 TN S32+=S8*S8 -struct SM90_64x192x32_S32S8S8_SS_TN +struct MMA_64x192x32_S32S8S8_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -12125,6 +12294,7 @@ struct SM90_64x192x32_S32S8S8_SS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -12174,7 +12344,7 @@ struct SM90_64x192x32_S32S8S8_SS_TN "l"(desc_b), "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x192x32_S32S8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x192x32_S32S8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -12182,7 +12352,7 @@ struct SM90_64x192x32_S32S8S8_SS_TN //////////////////////////////////////////////////////////////////////////////////////////////////// // GMMA 64x192x32 TN S32+=S8*S8 -struct SM90_64x192x32_S32S8S8_SS_TN_SATURATE +struct MMA_64x192x32_S32S8S8_SS_TN_SATURATE { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -12219,6 +12389,7 @@ struct SM90_64x192x32_S32S8S8_SS_TN_SATURATE GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -12268,7 +12439,7 @@ struct SM90_64x192x32_S32S8S8_SS_TN_SATURATE "l"(desc_b), "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x192x32_S32S8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x192x32_S32S8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -12277,7 +12448,7 @@ struct SM90_64x192x32_S32S8S8_SS_TN_SATURATE #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) // GMMA 64x208x32 TN S32+=S8*S8 -struct SM90_64x208x32_S32S8S8_SS_TN +struct MMA_64x208x32_S32S8S8_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -12316,6 +12487,7 @@ struct SM90_64x208x32_S32S8S8_SS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -12368,7 +12540,7 @@ struct SM90_64x208x32_S32S8S8_SS_TN "l"(desc_b), "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x208x32_S32S8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x208x32_S32S8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -12378,7 +12550,7 @@ struct SM90_64x208x32_S32S8S8_SS_TN #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) // GMMA 64x208x32 TN S32+=S8*S8 -struct SM90_64x208x32_S32S8S8_SS_TN_SATURATE +struct MMA_64x208x32_S32S8S8_SS_TN_SATURATE { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -12417,6 +12589,7 @@ struct SM90_64x208x32_S32S8S8_SS_TN_SATURATE GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if 
defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -12469,7 +12642,7 @@ struct SM90_64x208x32_S32S8S8_SS_TN_SATURATE "l"(desc_b), "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x208x32_S32S8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x208x32_S32S8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -12479,7 +12652,7 @@ struct SM90_64x208x32_S32S8S8_SS_TN_SATURATE #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) // GMMA 64x224x32 TN S32+=S8*S8 -struct SM90_64x224x32_S32S8S8_SS_TN +struct MMA_64x224x32_S32S8S8_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -12520,6 +12693,7 @@ struct SM90_64x224x32_S32S8S8_SS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -12575,7 +12749,7 @@ struct SM90_64x224x32_S32S8S8_SS_TN "l"(desc_b), "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x224x32_S32S8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x224x32_S32S8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -12585,7 +12759,7 @@ struct SM90_64x224x32_S32S8S8_SS_TN #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) // GMMA 64x224x32 TN S32+=S8*S8 -struct SM90_64x224x32_S32S8S8_SS_TN_SATURATE +struct MMA_64x224x32_S32S8S8_SS_TN_SATURATE { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -12626,6 +12800,7 @@ struct SM90_64x224x32_S32S8S8_SS_TN_SATURATE GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -12681,7 +12856,7 @@ struct SM90_64x224x32_S32S8S8_SS_TN_SATURATE "l"(desc_b), "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x224x32_S32S8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x224x32_S32S8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -12691,7 +12866,7 @@ struct SM90_64x224x32_S32S8S8_SS_TN_SATURATE #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) // GMMA 64x240x32 TN S32+=S8*S8 -struct SM90_64x240x32_S32S8S8_SS_TN +struct MMA_64x240x32_S32S8S8_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -12734,6 +12909,7 @@ struct SM90_64x240x32_S32S8S8_SS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -12792,7 +12968,7 @@ struct SM90_64x240x32_S32S8S8_SS_TN "l"(desc_b), "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x240x32_S32S8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x240x32_S32S8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -12802,7 +12978,7 @@ struct SM90_64x240x32_S32S8S8_SS_TN #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) // GMMA 64x240x32 TN S32+=S8*S8 -struct SM90_64x240x32_S32S8S8_SS_TN_SATURATE +struct MMA_64x240x32_S32S8S8_SS_TN_SATURATE { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -12845,6 
+13021,7 @@ struct SM90_64x240x32_S32S8S8_SS_TN_SATURATE GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -12903,7 +13080,7 @@ struct SM90_64x240x32_S32S8S8_SS_TN_SATURATE "l"(desc_b), "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x240x32_S32S8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x240x32_S32S8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -12912,7 +13089,7 @@ struct SM90_64x240x32_S32S8S8_SS_TN_SATURATE //////////////////////////////////////////////////////////////////////////////////////////////////// // GMMA 64x256x32 TN S32+=S8*S8 -struct SM90_64x256x32_S32S8S8_SS_TN +struct MMA_64x256x32_S32S8S8_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -12957,6 +13134,7 @@ struct SM90_64x256x32_S32S8S8_SS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -13018,7 +13196,7 @@ struct SM90_64x256x32_S32S8S8_SS_TN "l"(desc_b), "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x256x32_S32S8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x256x32_S32S8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -13026,7 +13204,7 @@ struct SM90_64x256x32_S32S8S8_SS_TN //////////////////////////////////////////////////////////////////////////////////////////////////// // GMMA 64x256x32 TN S32+=S8*S8 -struct SM90_64x256x32_S32S8S8_SS_TN_SATURATE +struct MMA_64x256x32_S32S8S8_SS_TN_SATURATE { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -13071,6 +13249,7 @@ struct SM90_64x256x32_S32S8S8_SS_TN_SATURATE GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -13132,7 +13311,7 @@ struct SM90_64x256x32_S32S8S8_SS_TN_SATURATE "l"(desc_b), "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x256x32_S32S8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x256x32_S32S8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -13140,7 +13319,7 @@ struct SM90_64x256x32_S32S8S8_SS_TN_SATURATE //////////////////////////////////////////////////////////////////////////////////////////////////// // GMMA 64x8x32 TN S32+=S8*S8 -struct SM90_64x8x32_S32S8S8_RS_TN +struct MMA_64x8x32_S32S8S8_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -13154,6 +13333,7 @@ struct SM90_64x8x32_S32S8S8_RS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -13169,7 +13349,7 @@ struct SM90_64x8x32_S32S8S8_RS_TN "l"(desc_b), "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x8x32_S32S8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x8x32_S32S8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -13177,7 +13357,7 @@ struct SM90_64x8x32_S32S8S8_RS_TN 
//////////////////////////////////////////////////////////////////////////////////////////////////// // GMMA 64x8x32 TN S32+=S8*S8 -struct SM90_64x8x32_S32S8S8_RS_TN_SATURATE +struct MMA_64x8x32_S32S8S8_RS_TN_SATURATE { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -13191,6 +13371,7 @@ struct SM90_64x8x32_S32S8S8_RS_TN_SATURATE GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -13206,7 +13387,7 @@ struct SM90_64x8x32_S32S8S8_RS_TN_SATURATE "l"(desc_b), "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x8x32_S32S8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x8x32_S32S8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -13214,7 +13395,7 @@ struct SM90_64x8x32_S32S8S8_RS_TN_SATURATE //////////////////////////////////////////////////////////////////////////////////////////////////// // GMMA 64x16x32 TN S32+=S8*S8 -struct SM90_64x16x32_S32S8S8_RS_TN +struct MMA_64x16x32_S32S8S8_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -13229,6 +13410,7 @@ struct SM90_64x16x32_S32S8S8_RS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -13245,7 +13427,7 @@ struct SM90_64x16x32_S32S8S8_RS_TN "l"(desc_b), "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x16x32_S32S8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x16x32_S32S8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -13253,7 +13435,7 @@ struct SM90_64x16x32_S32S8S8_RS_TN //////////////////////////////////////////////////////////////////////////////////////////////////// // GMMA 64x16x32 TN S32+=S8*S8 -struct SM90_64x16x32_S32S8S8_RS_TN_SATURATE +struct MMA_64x16x32_S32S8S8_RS_TN_SATURATE { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -13268,6 +13450,7 @@ struct SM90_64x16x32_S32S8S8_RS_TN_SATURATE GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -13284,7 +13467,7 @@ struct SM90_64x16x32_S32S8S8_RS_TN_SATURATE "l"(desc_b), "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x16x32_S32S8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x16x32_S32S8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -13292,7 +13475,7 @@ struct SM90_64x16x32_S32S8S8_RS_TN_SATURATE //////////////////////////////////////////////////////////////////////////////////////////////////// // GMMA 64x32x32 TN S32+=S8*S8 -struct SM90_64x32x32_S32S8S8_RS_TN +struct MMA_64x32x32_S32S8S8_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -13309,6 +13492,7 @@ struct SM90_64x32x32_S32S8S8_RS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -13328,7 +13512,7 @@ struct SM90_64x32x32_S32S8S8_RS_TN "l"(desc_b), "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use 
SM90_64x32x32_S32S8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x32x32_S32S8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -13336,7 +13520,7 @@ struct SM90_64x32x32_S32S8S8_RS_TN //////////////////////////////////////////////////////////////////////////////////////////////////// // GMMA 64x32x32 TN S32+=S8*S8 -struct SM90_64x32x32_S32S8S8_RS_TN_SATURATE +struct MMA_64x32x32_S32S8S8_RS_TN_SATURATE { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -13353,6 +13537,7 @@ struct SM90_64x32x32_S32S8S8_RS_TN_SATURATE GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -13372,7 +13557,7 @@ struct SM90_64x32x32_S32S8S8_RS_TN_SATURATE "l"(desc_b), "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x32x32_S32S8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x32x32_S32S8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -13381,7 +13566,7 @@ struct SM90_64x32x32_S32S8S8_RS_TN_SATURATE #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) // GMMA 64x48x32 TN S32+=S8*S8 -struct SM90_64x48x32_S32S8S8_RS_TN +struct MMA_64x48x32_S32S8S8_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -13400,6 +13585,7 @@ struct SM90_64x48x32_S32S8S8_RS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -13422,7 +13608,7 @@ struct SM90_64x48x32_S32S8S8_RS_TN "l"(desc_b), "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x48x32_S32S8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x48x32_S32S8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -13432,7 +13618,7 @@ struct SM90_64x48x32_S32S8S8_RS_TN #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) // GMMA 64x48x32 TN S32+=S8*S8 -struct SM90_64x48x32_S32S8S8_RS_TN_SATURATE +struct MMA_64x48x32_S32S8S8_RS_TN_SATURATE { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -13451,6 +13637,7 @@ struct SM90_64x48x32_S32S8S8_RS_TN_SATURATE GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -13473,7 +13660,7 @@ struct SM90_64x48x32_S32S8S8_RS_TN_SATURATE "l"(desc_b), "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x48x32_S32S8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x48x32_S32S8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -13482,7 +13669,7 @@ struct SM90_64x48x32_S32S8S8_RS_TN_SATURATE //////////////////////////////////////////////////////////////////////////////////////////////////// // GMMA 64x64x32 TN S32+=S8*S8 -struct SM90_64x64x32_S32S8S8_RS_TN +struct MMA_64x64x32_S32S8S8_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -13503,6 +13690,7 @@ struct SM90_64x64x32_S32S8S8_RS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ 
-13528,7 +13716,7 @@ struct SM90_64x64x32_S32S8S8_RS_TN "l"(desc_b), "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x64x32_S32S8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x64x32_S32S8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -13536,7 +13724,7 @@ struct SM90_64x64x32_S32S8S8_RS_TN //////////////////////////////////////////////////////////////////////////////////////////////////// // GMMA 64x64x32 TN S32+=S8*S8 -struct SM90_64x64x32_S32S8S8_RS_TN_SATURATE +struct MMA_64x64x32_S32S8S8_RS_TN_SATURATE { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -13557,6 +13745,7 @@ struct SM90_64x64x32_S32S8S8_RS_TN_SATURATE GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -13582,7 +13771,7 @@ struct SM90_64x64x32_S32S8S8_RS_TN_SATURATE "l"(desc_b), "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x64x32_S32S8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x64x32_S32S8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -13591,7 +13780,7 @@ struct SM90_64x64x32_S32S8S8_RS_TN_SATURATE #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) // GMMA 64x80x32 TN S32+=S8*S8 -struct SM90_64x80x32_S32S8S8_RS_TN +struct MMA_64x80x32_S32S8S8_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -13614,6 +13803,7 @@ struct SM90_64x80x32_S32S8S8_RS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -13642,7 +13832,7 @@ struct SM90_64x80x32_S32S8S8_RS_TN "l"(desc_b), "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x80x32_S32S8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x80x32_S32S8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -13652,7 +13842,7 @@ struct SM90_64x80x32_S32S8S8_RS_TN #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) // GMMA 64x80x32 TN S32+=S8*S8 -struct SM90_64x80x32_S32S8S8_RS_TN_SATURATE +struct MMA_64x80x32_S32S8S8_RS_TN_SATURATE { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -13675,6 +13865,7 @@ struct SM90_64x80x32_S32S8S8_RS_TN_SATURATE GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -13703,7 +13894,7 @@ struct SM90_64x80x32_S32S8S8_RS_TN_SATURATE "l"(desc_b), "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x80x32_S32S8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x80x32_S32S8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -13712,7 +13903,7 @@ struct SM90_64x80x32_S32S8S8_RS_TN_SATURATE //////////////////////////////////////////////////////////////////////////////////////////////////// // GMMA 64x96x32 TN S32+=S8*S8 -struct SM90_64x96x32_S32S8S8_RS_TN +struct MMA_64x96x32_S32S8S8_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -13737,6 +13928,7 @@ struct SM90_64x96x32_S32S8S8_RS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { 
#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -13768,7 +13960,7 @@ struct SM90_64x96x32_S32S8S8_RS_TN "l"(desc_b), "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x96x32_S32S8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x96x32_S32S8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -13776,7 +13968,7 @@ struct SM90_64x96x32_S32S8S8_RS_TN //////////////////////////////////////////////////////////////////////////////////////////////////// // GMMA 64x96x32 TN S32+=S8*S8 -struct SM90_64x96x32_S32S8S8_RS_TN_SATURATE +struct MMA_64x96x32_S32S8S8_RS_TN_SATURATE { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -13801,6 +13993,7 @@ struct SM90_64x96x32_S32S8S8_RS_TN_SATURATE GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -13832,7 +14025,7 @@ struct SM90_64x96x32_S32S8S8_RS_TN_SATURATE "l"(desc_b), "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x96x32_S32S8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x96x32_S32S8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -13841,7 +14034,7 @@ struct SM90_64x96x32_S32S8S8_RS_TN_SATURATE #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) // GMMA 64x112x32 TN S32+=S8*S8 -struct SM90_64x112x32_S32S8S8_RS_TN +struct MMA_64x112x32_S32S8S8_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -13868,6 +14061,7 @@ struct SM90_64x112x32_S32S8S8_RS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -13902,7 +14096,7 @@ struct SM90_64x112x32_S32S8S8_RS_TN "l"(desc_b), "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x112x32_S32S8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x112x32_S32S8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -13912,7 +14106,7 @@ struct SM90_64x112x32_S32S8S8_RS_TN #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) // GMMA 64x112x32 TN S32+=S8*S8 -struct SM90_64x112x32_S32S8S8_RS_TN_SATURATE +struct MMA_64x112x32_S32S8S8_RS_TN_SATURATE { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -13939,6 +14133,7 @@ struct SM90_64x112x32_S32S8S8_RS_TN_SATURATE GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -13973,7 +14168,7 @@ struct SM90_64x112x32_S32S8S8_RS_TN_SATURATE "l"(desc_b), "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x112x32_S32S8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x112x32_S32S8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -13982,7 +14177,7 @@ struct SM90_64x112x32_S32S8S8_RS_TN_SATURATE //////////////////////////////////////////////////////////////////////////////////////////////////// // GMMA 64x128x32 TN S32+=S8*S8 -struct SM90_64x128x32_S32S8S8_RS_TN +struct MMA_64x128x32_S32S8S8_RS_TN { using 
DRegisters = void; using ARegisters = uint32_t[4]; @@ -14011,6 +14206,7 @@ struct SM90_64x128x32_S32S8S8_RS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -14048,7 +14244,7 @@ struct SM90_64x128x32_S32S8S8_RS_TN "l"(desc_b), "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x128x32_S32S8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x128x32_S32S8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -14056,7 +14252,7 @@ struct SM90_64x128x32_S32S8S8_RS_TN //////////////////////////////////////////////////////////////////////////////////////////////////// // GMMA 64x128x32 TN S32+=S8*S8 -struct SM90_64x128x32_S32S8S8_RS_TN_SATURATE +struct MMA_64x128x32_S32S8S8_RS_TN_SATURATE { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -14085,6 +14281,7 @@ struct SM90_64x128x32_S32S8S8_RS_TN_SATURATE GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -14122,7 +14319,7 @@ struct SM90_64x128x32_S32S8S8_RS_TN_SATURATE "l"(desc_b), "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x128x32_S32S8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x128x32_S32S8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -14131,7 +14328,7 @@ struct SM90_64x128x32_S32S8S8_RS_TN_SATURATE #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) // GMMA 64x144x32 TN S32+=S8*S8 -struct SM90_64x144x32_S32S8S8_RS_TN +struct MMA_64x144x32_S32S8S8_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -14162,6 +14359,7 @@ struct SM90_64x144x32_S32S8S8_RS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -14202,7 +14400,7 @@ struct SM90_64x144x32_S32S8S8_RS_TN "l"(desc_b), "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x144x32_S32S8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x144x32_S32S8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -14212,7 +14410,7 @@ struct SM90_64x144x32_S32S8S8_RS_TN #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) // GMMA 64x144x32 TN S32+=S8*S8 -struct SM90_64x144x32_S32S8S8_RS_TN_SATURATE +struct MMA_64x144x32_S32S8S8_RS_TN_SATURATE { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -14243,6 +14441,7 @@ struct SM90_64x144x32_S32S8S8_RS_TN_SATURATE GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -14283,7 +14482,7 @@ struct SM90_64x144x32_S32S8S8_RS_TN_SATURATE "l"(desc_b), "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x144x32_S32S8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x144x32_S32S8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -14293,7 +14492,7 @@ struct SM90_64x144x32_S32S8S8_RS_TN_SATURATE #if 
defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) // GMMA 64x160x32 TN S32+=S8*S8 -struct SM90_64x160x32_S32S8S8_RS_TN +struct MMA_64x160x32_S32S8S8_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -14326,6 +14525,7 @@ struct SM90_64x160x32_S32S8S8_RS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -14369,7 +14569,7 @@ struct SM90_64x160x32_S32S8S8_RS_TN "l"(desc_b), "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x160x32_S32S8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x160x32_S32S8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -14379,7 +14579,7 @@ struct SM90_64x160x32_S32S8S8_RS_TN #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) // GMMA 64x160x32 TN S32+=S8*S8 -struct SM90_64x160x32_S32S8S8_RS_TN_SATURATE +struct MMA_64x160x32_S32S8S8_RS_TN_SATURATE { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -14412,6 +14612,7 @@ struct SM90_64x160x32_S32S8S8_RS_TN_SATURATE GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -14455,7 +14656,7 @@ struct SM90_64x160x32_S32S8S8_RS_TN_SATURATE "l"(desc_b), "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x160x32_S32S8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x160x32_S32S8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -14465,7 +14666,7 @@ struct SM90_64x160x32_S32S8S8_RS_TN_SATURATE #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) // GMMA 64x176x32 TN S32+=S8*S8 -struct SM90_64x176x32_S32S8S8_RS_TN +struct MMA_64x176x32_S32S8S8_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -14500,6 +14701,7 @@ struct SM90_64x176x32_S32S8S8_RS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -14546,7 +14748,7 @@ struct SM90_64x176x32_S32S8S8_RS_TN "l"(desc_b), "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x176x32_S32S8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x176x32_S32S8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -14556,7 +14758,7 @@ struct SM90_64x176x32_S32S8S8_RS_TN #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) // GMMA 64x176x32 TN S32+=S8*S8 -struct SM90_64x176x32_S32S8S8_RS_TN_SATURATE +struct MMA_64x176x32_S32S8S8_RS_TN_SATURATE { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -14591,6 +14793,7 @@ struct SM90_64x176x32_S32S8S8_RS_TN_SATURATE GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -14637,7 +14840,7 @@ struct SM90_64x176x32_S32S8S8_RS_TN_SATURATE "l"(desc_b), "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x176x32_S32S8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x176x32_S32S8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } 
}; @@ -14646,7 +14849,7 @@ struct SM90_64x176x32_S32S8S8_RS_TN_SATURATE //////////////////////////////////////////////////////////////////////////////////////////////////// // GMMA 64x192x32 TN S32+=S8*S8 -struct SM90_64x192x32_S32S8S8_RS_TN +struct MMA_64x192x32_S32S8S8_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -14683,6 +14886,7 @@ struct SM90_64x192x32_S32S8S8_RS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -14732,7 +14936,7 @@ struct SM90_64x192x32_S32S8S8_RS_TN "l"(desc_b), "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x192x32_S32S8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x192x32_S32S8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -14740,7 +14944,7 @@ struct SM90_64x192x32_S32S8S8_RS_TN //////////////////////////////////////////////////////////////////////////////////////////////////// // GMMA 64x192x32 TN S32+=S8*S8 -struct SM90_64x192x32_S32S8S8_RS_TN_SATURATE +struct MMA_64x192x32_S32S8S8_RS_TN_SATURATE { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -14777,6 +14981,7 @@ struct SM90_64x192x32_S32S8S8_RS_TN_SATURATE GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -14826,7 +15031,7 @@ struct SM90_64x192x32_S32S8S8_RS_TN_SATURATE "l"(desc_b), "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x192x32_S32S8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x192x32_S32S8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -14835,7 +15040,7 @@ struct SM90_64x192x32_S32S8S8_RS_TN_SATURATE #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) // GMMA 64x208x32 TN S32+=S8*S8 -struct SM90_64x208x32_S32S8S8_RS_TN +struct MMA_64x208x32_S32S8S8_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -14874,6 +15079,7 @@ struct SM90_64x208x32_S32S8S8_RS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -14926,7 +15132,7 @@ struct SM90_64x208x32_S32S8S8_RS_TN "l"(desc_b), "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x208x32_S32S8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x208x32_S32S8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -14936,7 +15142,7 @@ struct SM90_64x208x32_S32S8S8_RS_TN #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) // GMMA 64x208x32 TN S32+=S8*S8 -struct SM90_64x208x32_S32S8S8_RS_TN_SATURATE +struct MMA_64x208x32_S32S8S8_RS_TN_SATURATE { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -14975,6 +15181,7 @@ struct SM90_64x208x32_S32S8S8_RS_TN_SATURATE GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -15027,7 +15234,7 @@ struct SM90_64x208x32_S32S8S8_RS_TN_SATURATE "l"(desc_b), "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use 
SM90_64x208x32_S32S8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x208x32_S32S8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -15037,7 +15244,7 @@ struct SM90_64x208x32_S32S8S8_RS_TN_SATURATE #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) // GMMA 64x224x32 TN S32+=S8*S8 -struct SM90_64x224x32_S32S8S8_RS_TN +struct MMA_64x224x32_S32S8S8_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -15078,6 +15285,7 @@ struct SM90_64x224x32_S32S8S8_RS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -15133,7 +15341,7 @@ struct SM90_64x224x32_S32S8S8_RS_TN "l"(desc_b), "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x224x32_S32S8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x224x32_S32S8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -15143,7 +15351,7 @@ struct SM90_64x224x32_S32S8S8_RS_TN #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) // GMMA 64x224x32 TN S32+=S8*S8 -struct SM90_64x224x32_S32S8S8_RS_TN_SATURATE +struct MMA_64x224x32_S32S8S8_RS_TN_SATURATE { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -15184,6 +15392,7 @@ struct SM90_64x224x32_S32S8S8_RS_TN_SATURATE GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -15239,7 +15448,7 @@ struct SM90_64x224x32_S32S8S8_RS_TN_SATURATE "l"(desc_b), "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x224x32_S32S8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x224x32_S32S8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -15249,7 +15458,7 @@ struct SM90_64x224x32_S32S8S8_RS_TN_SATURATE #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) // GMMA 64x240x32 TN S32+=S8*S8 -struct SM90_64x240x32_S32S8S8_RS_TN +struct MMA_64x240x32_S32S8S8_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -15292,6 +15501,7 @@ struct SM90_64x240x32_S32S8S8_RS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -15350,7 +15560,7 @@ struct SM90_64x240x32_S32S8S8_RS_TN "l"(desc_b), "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x240x32_S32S8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x240x32_S32S8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -15360,7 +15570,7 @@ struct SM90_64x240x32_S32S8S8_RS_TN #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) // GMMA 64x240x32 TN S32+=S8*S8 -struct SM90_64x240x32_S32S8S8_RS_TN_SATURATE +struct MMA_64x240x32_S32S8S8_RS_TN_SATURATE { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -15403,6 +15613,7 @@ struct SM90_64x240x32_S32S8S8_RS_TN_SATURATE GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -15461,7 +15672,7 @@ struct SM90_64x240x32_S32S8S8_RS_TN_SATURATE "l"(desc_b), 
"r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x240x32_S32S8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x240x32_S32S8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -15470,7 +15681,7 @@ struct SM90_64x240x32_S32S8S8_RS_TN_SATURATE //////////////////////////////////////////////////////////////////////////////////////////////////// // GMMA 64x256x32 TN S32+=S8*S8 -struct SM90_64x256x32_S32S8S8_RS_TN +struct MMA_64x256x32_S32S8S8_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -15515,6 +15726,7 @@ struct SM90_64x256x32_S32S8S8_RS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -15576,7 +15788,7 @@ struct SM90_64x256x32_S32S8S8_RS_TN "l"(desc_b), "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x256x32_S32S8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x256x32_S32S8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -15584,7 +15796,7 @@ struct SM90_64x256x32_S32S8S8_RS_TN //////////////////////////////////////////////////////////////////////////////////////////////////// // GMMA 64x256x32 TN S32+=S8*S8 -struct SM90_64x256x32_S32S8S8_RS_TN_SATURATE +struct MMA_64x256x32_S32S8S8_RS_TN_SATURATE { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -15629,6 +15841,7 @@ struct SM90_64x256x32_S32S8S8_RS_TN_SATURATE GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -15690,7 +15903,7 @@ struct SM90_64x256x32_S32S8S8_RS_TN_SATURATE "l"(desc_b), "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x256x32_S32S8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x256x32_S32S8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -15698,7 +15911,7 @@ struct SM90_64x256x32_S32S8S8_RS_TN_SATURATE //////////////////////////////////////////////////////////////////////////////////////////////////// // GMMA 64x8x32 TN S32+=S8*U8 -struct SM90_64x8x32_S32S8U8_SS_TN +struct MMA_64x8x32_S32S8U8_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -15712,6 +15925,7 @@ struct SM90_64x8x32_S32S8U8_SS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -15727,7 +15941,7 @@ struct SM90_64x8x32_S32S8U8_SS_TN "l"(desc_b), "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x8x32_S32S8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x8x32_S32S8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -15735,7 +15949,7 @@ struct SM90_64x8x32_S32S8U8_SS_TN //////////////////////////////////////////////////////////////////////////////////////////////////// // GMMA 64x8x32 TN S32+=S8*U8 -struct SM90_64x8x32_S32S8U8_SS_TN_SATURATE +struct MMA_64x8x32_S32S8U8_SS_TN_SATURATE { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -15749,6 +15963,7 @@ struct SM90_64x8x32_S32S8U8_SS_TN_SATURATE GMMA::ScaleOut 
const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -15764,7 +15979,7 @@ struct SM90_64x8x32_S32S8U8_SS_TN_SATURATE "l"(desc_b), "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x8x32_S32S8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x8x32_S32S8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -15772,7 +15987,7 @@ struct SM90_64x8x32_S32S8U8_SS_TN_SATURATE //////////////////////////////////////////////////////////////////////////////////////////////////// // GMMA 64x16x32 TN S32+=S8*U8 -struct SM90_64x16x32_S32S8U8_SS_TN +struct MMA_64x16x32_S32S8U8_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -15787,6 +16002,7 @@ struct SM90_64x16x32_S32S8U8_SS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -15803,7 +16019,7 @@ struct SM90_64x16x32_S32S8U8_SS_TN "l"(desc_b), "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x16x32_S32S8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x16x32_S32S8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -15811,7 +16027,7 @@ struct SM90_64x16x32_S32S8U8_SS_TN //////////////////////////////////////////////////////////////////////////////////////////////////// // GMMA 64x16x32 TN S32+=S8*U8 -struct SM90_64x16x32_S32S8U8_SS_TN_SATURATE +struct MMA_64x16x32_S32S8U8_SS_TN_SATURATE { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -15826,6 +16042,7 @@ struct SM90_64x16x32_S32S8U8_SS_TN_SATURATE GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -15842,7 +16059,7 @@ struct SM90_64x16x32_S32S8U8_SS_TN_SATURATE "l"(desc_b), "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x16x32_S32S8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x16x32_S32S8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -15850,7 +16067,7 @@ struct SM90_64x16x32_S32S8U8_SS_TN_SATURATE //////////////////////////////////////////////////////////////////////////////////////////////////// // GMMA 64x32x32 TN S32+=S8*U8 -struct SM90_64x32x32_S32S8U8_SS_TN +struct MMA_64x32x32_S32S8U8_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -15867,6 +16084,7 @@ struct SM90_64x32x32_S32S8U8_SS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -15886,7 +16104,7 @@ struct SM90_64x32x32_S32S8U8_SS_TN "l"(desc_b), "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x32x32_S32S8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x32x32_S32S8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -15894,7 +16112,7 @@ struct SM90_64x32x32_S32S8U8_SS_TN 
//////////////////////////////////////////////////////////////////////////////////////////////////// // GMMA 64x32x32 TN S32+=S8*U8 -struct SM90_64x32x32_S32S8U8_SS_TN_SATURATE +struct MMA_64x32x32_S32S8U8_SS_TN_SATURATE { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -15911,6 +16129,7 @@ struct SM90_64x32x32_S32S8U8_SS_TN_SATURATE GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -15930,7 +16149,7 @@ struct SM90_64x32x32_S32S8U8_SS_TN_SATURATE "l"(desc_b), "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x32x32_S32S8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x32x32_S32S8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -15939,7 +16158,7 @@ struct SM90_64x32x32_S32S8U8_SS_TN_SATURATE #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) // GMMA 64x48x32 TN S32+=S8*U8 -struct SM90_64x48x32_S32S8U8_SS_TN +struct MMA_64x48x32_S32S8U8_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -15958,6 +16177,7 @@ struct SM90_64x48x32_S32S8U8_SS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -15980,7 +16200,7 @@ struct SM90_64x48x32_S32S8U8_SS_TN "l"(desc_b), "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x48x32_S32S8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x48x32_S32S8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -15990,7 +16210,7 @@ struct SM90_64x48x32_S32S8U8_SS_TN #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) // GMMA 64x48x32 TN S32+=S8*U8 -struct SM90_64x48x32_S32S8U8_SS_TN_SATURATE +struct MMA_64x48x32_S32S8U8_SS_TN_SATURATE { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -16009,6 +16229,7 @@ struct SM90_64x48x32_S32S8U8_SS_TN_SATURATE GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -16031,7 +16252,7 @@ struct SM90_64x48x32_S32S8U8_SS_TN_SATURATE "l"(desc_b), "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x48x32_S32S8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x48x32_S32S8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -16040,7 +16261,7 @@ struct SM90_64x48x32_S32S8U8_SS_TN_SATURATE //////////////////////////////////////////////////////////////////////////////////////////////////// // GMMA 64x64x32 TN S32+=S8*U8 -struct SM90_64x64x32_S32S8U8_SS_TN +struct MMA_64x64x32_S32S8U8_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -16061,6 +16282,7 @@ struct SM90_64x64x32_S32S8U8_SS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -16086,7 +16308,7 @@ struct SM90_64x64x32_S32S8U8_SS_TN "l"(desc_b), "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x64x32_S32S8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + 
CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x64x32_S32S8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -16094,7 +16316,7 @@ struct SM90_64x64x32_S32S8U8_SS_TN //////////////////////////////////////////////////////////////////////////////////////////////////// // GMMA 64x64x32 TN S32+=S8*U8 -struct SM90_64x64x32_S32S8U8_SS_TN_SATURATE +struct MMA_64x64x32_S32S8U8_SS_TN_SATURATE { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -16115,6 +16337,7 @@ struct SM90_64x64x32_S32S8U8_SS_TN_SATURATE GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -16140,7 +16363,7 @@ struct SM90_64x64x32_S32S8U8_SS_TN_SATURATE "l"(desc_b), "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x64x32_S32S8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x64x32_S32S8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -16149,7 +16372,7 @@ struct SM90_64x64x32_S32S8U8_SS_TN_SATURATE #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) // GMMA 64x80x32 TN S32+=S8*U8 -struct SM90_64x80x32_S32S8U8_SS_TN +struct MMA_64x80x32_S32S8U8_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -16172,6 +16395,7 @@ struct SM90_64x80x32_S32S8U8_SS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -16200,7 +16424,7 @@ struct SM90_64x80x32_S32S8U8_SS_TN "l"(desc_b), "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x80x32_S32S8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x80x32_S32S8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -16210,7 +16434,7 @@ struct SM90_64x80x32_S32S8U8_SS_TN #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) // GMMA 64x80x32 TN S32+=S8*U8 -struct SM90_64x80x32_S32S8U8_SS_TN_SATURATE +struct MMA_64x80x32_S32S8U8_SS_TN_SATURATE { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -16233,6 +16457,7 @@ struct SM90_64x80x32_S32S8U8_SS_TN_SATURATE GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -16261,7 +16486,7 @@ struct SM90_64x80x32_S32S8U8_SS_TN_SATURATE "l"(desc_b), "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x80x32_S32S8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x80x32_S32S8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -16270,7 +16495,7 @@ struct SM90_64x80x32_S32S8U8_SS_TN_SATURATE //////////////////////////////////////////////////////////////////////////////////////////////////// // GMMA 64x96x32 TN S32+=S8*U8 -struct SM90_64x96x32_S32S8U8_SS_TN +struct MMA_64x96x32_S32S8U8_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -16295,6 +16520,7 @@ struct SM90_64x96x32_S32S8U8_SS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -16326,7 +16552,7 @@ struct 
SM90_64x96x32_S32S8U8_SS_TN "l"(desc_b), "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x96x32_S32S8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x96x32_S32S8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -16334,7 +16560,7 @@ struct SM90_64x96x32_S32S8U8_SS_TN //////////////////////////////////////////////////////////////////////////////////////////////////// // GMMA 64x96x32 TN S32+=S8*U8 -struct SM90_64x96x32_S32S8U8_SS_TN_SATURATE +struct MMA_64x96x32_S32S8U8_SS_TN_SATURATE { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -16359,6 +16585,7 @@ struct SM90_64x96x32_S32S8U8_SS_TN_SATURATE GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -16390,7 +16617,7 @@ struct SM90_64x96x32_S32S8U8_SS_TN_SATURATE "l"(desc_b), "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x96x32_S32S8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x96x32_S32S8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -16399,7 +16626,7 @@ struct SM90_64x96x32_S32S8U8_SS_TN_SATURATE #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) // GMMA 64x112x32 TN S32+=S8*U8 -struct SM90_64x112x32_S32S8U8_SS_TN +struct MMA_64x112x32_S32S8U8_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -16426,6 +16653,7 @@ struct SM90_64x112x32_S32S8U8_SS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -16460,7 +16688,7 @@ struct SM90_64x112x32_S32S8U8_SS_TN "l"(desc_b), "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x112x32_S32S8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x112x32_S32S8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -16470,7 +16698,7 @@ struct SM90_64x112x32_S32S8U8_SS_TN #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) // GMMA 64x112x32 TN S32+=S8*U8 -struct SM90_64x112x32_S32S8U8_SS_TN_SATURATE +struct MMA_64x112x32_S32S8U8_SS_TN_SATURATE { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -16497,6 +16725,7 @@ struct SM90_64x112x32_S32S8U8_SS_TN_SATURATE GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -16531,7 +16760,7 @@ struct SM90_64x112x32_S32S8U8_SS_TN_SATURATE "l"(desc_b), "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x112x32_S32S8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x112x32_S32S8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -16540,7 +16769,7 @@ struct SM90_64x112x32_S32S8U8_SS_TN_SATURATE //////////////////////////////////////////////////////////////////////////////////////////////////// // GMMA 64x128x32 TN S32+=S8*U8 -struct SM90_64x128x32_S32S8U8_SS_TN +struct MMA_64x128x32_S32S8U8_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -16569,6 +16798,7 @@ struct SM90_64x128x32_S32S8U8_SS_TN GMMA::ScaleOut const scale_D = 
GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -16606,7 +16836,7 @@ struct SM90_64x128x32_S32S8U8_SS_TN "l"(desc_b), "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x128x32_S32S8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x128x32_S32S8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -16614,7 +16844,7 @@ struct SM90_64x128x32_S32S8U8_SS_TN //////////////////////////////////////////////////////////////////////////////////////////////////// // GMMA 64x128x32 TN S32+=S8*U8 -struct SM90_64x128x32_S32S8U8_SS_TN_SATURATE +struct MMA_64x128x32_S32S8U8_SS_TN_SATURATE { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -16643,6 +16873,7 @@ struct SM90_64x128x32_S32S8U8_SS_TN_SATURATE GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -16680,7 +16911,7 @@ struct SM90_64x128x32_S32S8U8_SS_TN_SATURATE "l"(desc_b), "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x128x32_S32S8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x128x32_S32S8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -16689,7 +16920,7 @@ struct SM90_64x128x32_S32S8U8_SS_TN_SATURATE #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) // GMMA 64x144x32 TN S32+=S8*U8 -struct SM90_64x144x32_S32S8U8_SS_TN +struct MMA_64x144x32_S32S8U8_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -16720,6 +16951,7 @@ struct SM90_64x144x32_S32S8U8_SS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -16760,7 +16992,7 @@ struct SM90_64x144x32_S32S8U8_SS_TN "l"(desc_b), "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x144x32_S32S8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x144x32_S32S8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -16770,7 +17002,7 @@ struct SM90_64x144x32_S32S8U8_SS_TN #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) // GMMA 64x144x32 TN S32+=S8*U8 -struct SM90_64x144x32_S32S8U8_SS_TN_SATURATE +struct MMA_64x144x32_S32S8U8_SS_TN_SATURATE { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -16801,6 +17033,7 @@ struct SM90_64x144x32_S32S8U8_SS_TN_SATURATE GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -16841,7 +17074,7 @@ struct SM90_64x144x32_S32S8U8_SS_TN_SATURATE "l"(desc_b), "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x144x32_S32S8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x144x32_S32S8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -16851,7 +17084,7 @@ struct SM90_64x144x32_S32S8U8_SS_TN_SATURATE #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) // GMMA 64x160x32 TN S32+=S8*U8 -struct SM90_64x160x32_S32S8U8_SS_TN +struct 
MMA_64x160x32_S32S8U8_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -16884,6 +17117,7 @@ struct SM90_64x160x32_S32S8U8_SS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -16927,7 +17161,7 @@ struct SM90_64x160x32_S32S8U8_SS_TN "l"(desc_b), "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x160x32_S32S8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x160x32_S32S8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -16937,7 +17171,7 @@ struct SM90_64x160x32_S32S8U8_SS_TN #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) // GMMA 64x160x32 TN S32+=S8*U8 -struct SM90_64x160x32_S32S8U8_SS_TN_SATURATE +struct MMA_64x160x32_S32S8U8_SS_TN_SATURATE { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -16970,6 +17204,7 @@ struct SM90_64x160x32_S32S8U8_SS_TN_SATURATE GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -17013,7 +17248,7 @@ struct SM90_64x160x32_S32S8U8_SS_TN_SATURATE "l"(desc_b), "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x160x32_S32S8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x160x32_S32S8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -17023,7 +17258,7 @@ struct SM90_64x160x32_S32S8U8_SS_TN_SATURATE #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) // GMMA 64x176x32 TN S32+=S8*U8 -struct SM90_64x176x32_S32S8U8_SS_TN +struct MMA_64x176x32_S32S8U8_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -17058,6 +17293,7 @@ struct SM90_64x176x32_S32S8U8_SS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -17104,7 +17340,7 @@ struct SM90_64x176x32_S32S8U8_SS_TN "l"(desc_b), "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x176x32_S32S8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x176x32_S32S8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -17114,7 +17350,7 @@ struct SM90_64x176x32_S32S8U8_SS_TN #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) // GMMA 64x176x32 TN S32+=S8*U8 -struct SM90_64x176x32_S32S8U8_SS_TN_SATURATE +struct MMA_64x176x32_S32S8U8_SS_TN_SATURATE { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -17149,6 +17385,7 @@ struct SM90_64x176x32_S32S8U8_SS_TN_SATURATE GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -17195,7 +17432,7 @@ struct SM90_64x176x32_S32S8U8_SS_TN_SATURATE "l"(desc_b), "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x176x32_S32S8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x176x32_S32S8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -17204,7 +17441,7 @@ struct SM90_64x176x32_S32S8U8_SS_TN_SATURATE 
//////////////////////////////////////////////////////////////////////////////////////////////////// // GMMA 64x192x32 TN S32+=S8*U8 -struct SM90_64x192x32_S32S8U8_SS_TN +struct MMA_64x192x32_S32S8U8_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -17241,6 +17478,7 @@ struct SM90_64x192x32_S32S8U8_SS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -17290,7 +17528,7 @@ struct SM90_64x192x32_S32S8U8_SS_TN "l"(desc_b), "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x192x32_S32S8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x192x32_S32S8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -17298,7 +17536,7 @@ struct SM90_64x192x32_S32S8U8_SS_TN //////////////////////////////////////////////////////////////////////////////////////////////////// // GMMA 64x192x32 TN S32+=S8*U8 -struct SM90_64x192x32_S32S8U8_SS_TN_SATURATE +struct MMA_64x192x32_S32S8U8_SS_TN_SATURATE { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -17335,6 +17573,7 @@ struct SM90_64x192x32_S32S8U8_SS_TN_SATURATE GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -17384,7 +17623,7 @@ struct SM90_64x192x32_S32S8U8_SS_TN_SATURATE "l"(desc_b), "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x192x32_S32S8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x192x32_S32S8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -17393,7 +17632,7 @@ struct SM90_64x192x32_S32S8U8_SS_TN_SATURATE #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) // GMMA 64x208x32 TN S32+=S8*U8 -struct SM90_64x208x32_S32S8U8_SS_TN +struct MMA_64x208x32_S32S8U8_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -17432,6 +17671,7 @@ struct SM90_64x208x32_S32S8U8_SS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -17484,7 +17724,7 @@ struct SM90_64x208x32_S32S8U8_SS_TN "l"(desc_b), "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x208x32_S32S8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x208x32_S32S8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -17494,7 +17734,7 @@ struct SM90_64x208x32_S32S8U8_SS_TN #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) // GMMA 64x208x32 TN S32+=S8*U8 -struct SM90_64x208x32_S32S8U8_SS_TN_SATURATE +struct MMA_64x208x32_S32S8U8_SS_TN_SATURATE { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -17533,6 +17773,7 @@ struct SM90_64x208x32_S32S8U8_SS_TN_SATURATE GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -17585,7 +17826,7 @@ struct SM90_64x208x32_S32S8U8_SS_TN_SATURATE "l"(desc_b), "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x208x32_S32S8U8_SS_TN_SATURATE without 
CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x208x32_S32S8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -17595,7 +17836,7 @@ struct SM90_64x208x32_S32S8U8_SS_TN_SATURATE #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) // GMMA 64x224x32 TN S32+=S8*U8 -struct SM90_64x224x32_S32S8U8_SS_TN +struct MMA_64x224x32_S32S8U8_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -17636,6 +17877,7 @@ struct SM90_64x224x32_S32S8U8_SS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -17691,7 +17933,7 @@ struct SM90_64x224x32_S32S8U8_SS_TN "l"(desc_b), "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x224x32_S32S8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x224x32_S32S8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -17701,7 +17943,7 @@ struct SM90_64x224x32_S32S8U8_SS_TN #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) // GMMA 64x224x32 TN S32+=S8*U8 -struct SM90_64x224x32_S32S8U8_SS_TN_SATURATE +struct MMA_64x224x32_S32S8U8_SS_TN_SATURATE { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -17742,6 +17984,7 @@ struct SM90_64x224x32_S32S8U8_SS_TN_SATURATE GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -17797,7 +18040,7 @@ struct SM90_64x224x32_S32S8U8_SS_TN_SATURATE "l"(desc_b), "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x224x32_S32S8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x224x32_S32S8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -17807,7 +18050,7 @@ struct SM90_64x224x32_S32S8U8_SS_TN_SATURATE #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) // GMMA 64x240x32 TN S32+=S8*U8 -struct SM90_64x240x32_S32S8U8_SS_TN +struct MMA_64x240x32_S32S8U8_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -17850,6 +18093,7 @@ struct SM90_64x240x32_S32S8U8_SS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -17908,7 +18152,7 @@ struct SM90_64x240x32_S32S8U8_SS_TN "l"(desc_b), "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x240x32_S32S8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x240x32_S32S8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -17918,7 +18162,7 @@ struct SM90_64x240x32_S32S8U8_SS_TN #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) // GMMA 64x240x32 TN S32+=S8*U8 -struct SM90_64x240x32_S32S8U8_SS_TN_SATURATE +struct MMA_64x240x32_S32S8U8_SS_TN_SATURATE { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -17961,6 +18205,7 @@ struct SM90_64x240x32_S32S8U8_SS_TN_SATURATE GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -18019,7 +18264,7 @@ struct SM90_64x240x32_S32S8U8_SS_TN_SATURATE "l"(desc_b), 
"r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x240x32_S32S8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x240x32_S32S8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -18028,7 +18273,7 @@ struct SM90_64x240x32_S32S8U8_SS_TN_SATURATE //////////////////////////////////////////////////////////////////////////////////////////////////// // GMMA 64x256x32 TN S32+=S8*U8 -struct SM90_64x256x32_S32S8U8_SS_TN +struct MMA_64x256x32_S32S8U8_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -18073,6 +18318,7 @@ struct SM90_64x256x32_S32S8U8_SS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -18134,7 +18380,7 @@ struct SM90_64x256x32_S32S8U8_SS_TN "l"(desc_b), "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x256x32_S32S8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x256x32_S32S8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -18142,7 +18388,7 @@ struct SM90_64x256x32_S32S8U8_SS_TN //////////////////////////////////////////////////////////////////////////////////////////////////// // GMMA 64x256x32 TN S32+=S8*U8 -struct SM90_64x256x32_S32S8U8_SS_TN_SATURATE +struct MMA_64x256x32_S32S8U8_SS_TN_SATURATE { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -18187,6 +18433,7 @@ struct SM90_64x256x32_S32S8U8_SS_TN_SATURATE GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -18248,7 +18495,7 @@ struct SM90_64x256x32_S32S8U8_SS_TN_SATURATE "l"(desc_b), "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x256x32_S32S8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x256x32_S32S8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -18256,7 +18503,7 @@ struct SM90_64x256x32_S32S8U8_SS_TN_SATURATE //////////////////////////////////////////////////////////////////////////////////////////////////// // GMMA 64x8x32 TN S32+=S8*U8 -struct SM90_64x8x32_S32S8U8_RS_TN +struct MMA_64x8x32_S32S8U8_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -18270,6 +18517,7 @@ struct SM90_64x8x32_S32S8U8_RS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -18285,7 +18533,7 @@ struct SM90_64x8x32_S32S8U8_RS_TN "l"(desc_b), "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x8x32_S32S8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x8x32_S32S8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -18293,7 +18541,7 @@ struct SM90_64x8x32_S32S8U8_RS_TN //////////////////////////////////////////////////////////////////////////////////////////////////// // GMMA 64x8x32 TN S32+=S8*U8 -struct SM90_64x8x32_S32S8U8_RS_TN_SATURATE +struct MMA_64x8x32_S32S8U8_RS_TN_SATURATE { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -18307,6 +18555,7 @@ struct SM90_64x8x32_S32S8U8_RS_TN_SATURATE 
GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -18322,7 +18571,7 @@ struct SM90_64x8x32_S32S8U8_RS_TN_SATURATE "l"(desc_b), "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x8x32_S32S8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x8x32_S32S8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -18330,7 +18579,7 @@ struct SM90_64x8x32_S32S8U8_RS_TN_SATURATE //////////////////////////////////////////////////////////////////////////////////////////////////// // GMMA 64x16x32 TN S32+=S8*U8 -struct SM90_64x16x32_S32S8U8_RS_TN +struct MMA_64x16x32_S32S8U8_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -18345,6 +18594,7 @@ struct SM90_64x16x32_S32S8U8_RS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -18361,7 +18611,7 @@ struct SM90_64x16x32_S32S8U8_RS_TN "l"(desc_b), "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x16x32_S32S8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x16x32_S32S8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -18369,7 +18619,7 @@ struct SM90_64x16x32_S32S8U8_RS_TN //////////////////////////////////////////////////////////////////////////////////////////////////// // GMMA 64x16x32 TN S32+=S8*U8 -struct SM90_64x16x32_S32S8U8_RS_TN_SATURATE +struct MMA_64x16x32_S32S8U8_RS_TN_SATURATE { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -18384,6 +18634,7 @@ struct SM90_64x16x32_S32S8U8_RS_TN_SATURATE GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -18400,7 +18651,7 @@ struct SM90_64x16x32_S32S8U8_RS_TN_SATURATE "l"(desc_b), "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x16x32_S32S8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x16x32_S32S8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -18408,7 +18659,7 @@ struct SM90_64x16x32_S32S8U8_RS_TN_SATURATE //////////////////////////////////////////////////////////////////////////////////////////////////// // GMMA 64x32x32 TN S32+=S8*U8 -struct SM90_64x32x32_S32S8U8_RS_TN +struct MMA_64x32x32_S32S8U8_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -18425,6 +18676,7 @@ struct SM90_64x32x32_S32S8U8_RS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -18444,7 +18696,7 @@ struct SM90_64x32x32_S32S8U8_RS_TN "l"(desc_b), "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x32x32_S32S8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x32x32_S32S8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -18452,7 +18704,7 @@ struct SM90_64x32x32_S32S8U8_RS_TN //////////////////////////////////////////////////////////////////////////////////////////////////// // GMMA 
64x32x32 TN S32+=S8*U8 -struct SM90_64x32x32_S32S8U8_RS_TN_SATURATE +struct MMA_64x32x32_S32S8U8_RS_TN_SATURATE { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -18469,6 +18721,7 @@ struct SM90_64x32x32_S32S8U8_RS_TN_SATURATE GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -18488,7 +18741,7 @@ struct SM90_64x32x32_S32S8U8_RS_TN_SATURATE "l"(desc_b), "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x32x32_S32S8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x32x32_S32S8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -18497,7 +18750,7 @@ struct SM90_64x32x32_S32S8U8_RS_TN_SATURATE #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) // GMMA 64x48x32 TN S32+=S8*U8 -struct SM90_64x48x32_S32S8U8_RS_TN +struct MMA_64x48x32_S32S8U8_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -18516,6 +18769,7 @@ struct SM90_64x48x32_S32S8U8_RS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -18538,7 +18792,7 @@ struct SM90_64x48x32_S32S8U8_RS_TN "l"(desc_b), "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x48x32_S32S8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x48x32_S32S8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -18548,7 +18802,7 @@ struct SM90_64x48x32_S32S8U8_RS_TN #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) // GMMA 64x48x32 TN S32+=S8*U8 -struct SM90_64x48x32_S32S8U8_RS_TN_SATURATE +struct MMA_64x48x32_S32S8U8_RS_TN_SATURATE { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -18567,6 +18821,7 @@ struct SM90_64x48x32_S32S8U8_RS_TN_SATURATE GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -18589,7 +18844,7 @@ struct SM90_64x48x32_S32S8U8_RS_TN_SATURATE "l"(desc_b), "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x48x32_S32S8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x48x32_S32S8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -18598,7 +18853,7 @@ struct SM90_64x48x32_S32S8U8_RS_TN_SATURATE //////////////////////////////////////////////////////////////////////////////////////////////////// // GMMA 64x64x32 TN S32+=S8*U8 -struct SM90_64x64x32_S32S8U8_RS_TN +struct MMA_64x64x32_S32S8U8_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -18619,6 +18874,7 @@ struct SM90_64x64x32_S32S8U8_RS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -18644,7 +18900,7 @@ struct SM90_64x64x32_S32S8U8_RS_TN "l"(desc_b), "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x64x32_S32S8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x64x32_S32S8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -18652,7 +18908,7 @@ 
struct SM90_64x64x32_S32S8U8_RS_TN //////////////////////////////////////////////////////////////////////////////////////////////////// // GMMA 64x64x32 TN S32+=S8*U8 -struct SM90_64x64x32_S32S8U8_RS_TN_SATURATE +struct MMA_64x64x32_S32S8U8_RS_TN_SATURATE { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -18673,6 +18929,7 @@ struct SM90_64x64x32_S32S8U8_RS_TN_SATURATE GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -18698,7 +18955,7 @@ struct SM90_64x64x32_S32S8U8_RS_TN_SATURATE "l"(desc_b), "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x64x32_S32S8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x64x32_S32S8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -18707,7 +18964,7 @@ struct SM90_64x64x32_S32S8U8_RS_TN_SATURATE #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) // GMMA 64x80x32 TN S32+=S8*U8 -struct SM90_64x80x32_S32S8U8_RS_TN +struct MMA_64x80x32_S32S8U8_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -18730,6 +18987,7 @@ struct SM90_64x80x32_S32S8U8_RS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -18758,7 +19016,7 @@ struct SM90_64x80x32_S32S8U8_RS_TN "l"(desc_b), "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x80x32_S32S8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x80x32_S32S8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -18768,7 +19026,7 @@ struct SM90_64x80x32_S32S8U8_RS_TN #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) // GMMA 64x80x32 TN S32+=S8*U8 -struct SM90_64x80x32_S32S8U8_RS_TN_SATURATE +struct MMA_64x80x32_S32S8U8_RS_TN_SATURATE { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -18791,6 +19049,7 @@ struct SM90_64x80x32_S32S8U8_RS_TN_SATURATE GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -18819,7 +19078,7 @@ struct SM90_64x80x32_S32S8U8_RS_TN_SATURATE "l"(desc_b), "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x80x32_S32S8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x80x32_S32S8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -18828,7 +19087,7 @@ struct SM90_64x80x32_S32S8U8_RS_TN_SATURATE //////////////////////////////////////////////////////////////////////////////////////////////////// // GMMA 64x96x32 TN S32+=S8*U8 -struct SM90_64x96x32_S32S8U8_RS_TN +struct MMA_64x96x32_S32S8U8_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -18853,6 +19112,7 @@ struct SM90_64x96x32_S32S8U8_RS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -18884,7 +19144,7 @@ struct SM90_64x96x32_S32S8U8_RS_TN "l"(desc_b), "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x96x32_S32S8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + 
CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x96x32_S32S8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -18892,7 +19152,7 @@ struct SM90_64x96x32_S32S8U8_RS_TN //////////////////////////////////////////////////////////////////////////////////////////////////// // GMMA 64x96x32 TN S32+=S8*U8 -struct SM90_64x96x32_S32S8U8_RS_TN_SATURATE +struct MMA_64x96x32_S32S8U8_RS_TN_SATURATE { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -18917,6 +19177,7 @@ struct SM90_64x96x32_S32S8U8_RS_TN_SATURATE GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -18948,7 +19209,7 @@ struct SM90_64x96x32_S32S8U8_RS_TN_SATURATE "l"(desc_b), "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x96x32_S32S8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x96x32_S32S8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -18957,7 +19218,7 @@ struct SM90_64x96x32_S32S8U8_RS_TN_SATURATE #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) // GMMA 64x112x32 TN S32+=S8*U8 -struct SM90_64x112x32_S32S8U8_RS_TN +struct MMA_64x112x32_S32S8U8_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -18984,6 +19245,7 @@ struct SM90_64x112x32_S32S8U8_RS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -19018,7 +19280,7 @@ struct SM90_64x112x32_S32S8U8_RS_TN "l"(desc_b), "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x112x32_S32S8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x112x32_S32S8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -19028,7 +19290,7 @@ struct SM90_64x112x32_S32S8U8_RS_TN #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) // GMMA 64x112x32 TN S32+=S8*U8 -struct SM90_64x112x32_S32S8U8_RS_TN_SATURATE +struct MMA_64x112x32_S32S8U8_RS_TN_SATURATE { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -19055,6 +19317,7 @@ struct SM90_64x112x32_S32S8U8_RS_TN_SATURATE GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -19089,7 +19352,7 @@ struct SM90_64x112x32_S32S8U8_RS_TN_SATURATE "l"(desc_b), "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x112x32_S32S8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x112x32_S32S8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -19098,7 +19361,7 @@ struct SM90_64x112x32_S32S8U8_RS_TN_SATURATE //////////////////////////////////////////////////////////////////////////////////////////////////// // GMMA 64x128x32 TN S32+=S8*U8 -struct SM90_64x128x32_S32S8U8_RS_TN +struct MMA_64x128x32_S32S8U8_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -19127,6 +19390,7 @@ struct SM90_64x128x32_S32S8U8_RS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -19164,7 +19428,7 @@ struct 
SM90_64x128x32_S32S8U8_RS_TN "l"(desc_b), "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x128x32_S32S8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x128x32_S32S8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -19172,7 +19436,7 @@ struct SM90_64x128x32_S32S8U8_RS_TN //////////////////////////////////////////////////////////////////////////////////////////////////// // GMMA 64x128x32 TN S32+=S8*U8 -struct SM90_64x128x32_S32S8U8_RS_TN_SATURATE +struct MMA_64x128x32_S32S8U8_RS_TN_SATURATE { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -19201,6 +19465,7 @@ struct SM90_64x128x32_S32S8U8_RS_TN_SATURATE GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -19238,7 +19503,7 @@ struct SM90_64x128x32_S32S8U8_RS_TN_SATURATE "l"(desc_b), "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x128x32_S32S8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x128x32_S32S8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -19247,7 +19512,7 @@ struct SM90_64x128x32_S32S8U8_RS_TN_SATURATE #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) // GMMA 64x144x32 TN S32+=S8*U8 -struct SM90_64x144x32_S32S8U8_RS_TN +struct MMA_64x144x32_S32S8U8_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -19278,6 +19543,7 @@ struct SM90_64x144x32_S32S8U8_RS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -19318,7 +19584,7 @@ struct SM90_64x144x32_S32S8U8_RS_TN "l"(desc_b), "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x144x32_S32S8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x144x32_S32S8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -19328,7 +19594,7 @@ struct SM90_64x144x32_S32S8U8_RS_TN #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) // GMMA 64x144x32 TN S32+=S8*U8 -struct SM90_64x144x32_S32S8U8_RS_TN_SATURATE +struct MMA_64x144x32_S32S8U8_RS_TN_SATURATE { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -19359,6 +19625,7 @@ struct SM90_64x144x32_S32S8U8_RS_TN_SATURATE GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -19399,7 +19666,7 @@ struct SM90_64x144x32_S32S8U8_RS_TN_SATURATE "l"(desc_b), "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x144x32_S32S8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x144x32_S32S8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -19409,7 +19676,7 @@ struct SM90_64x144x32_S32S8U8_RS_TN_SATURATE #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) // GMMA 64x160x32 TN S32+=S8*U8 -struct SM90_64x160x32_S32S8U8_RS_TN +struct MMA_64x160x32_S32S8U8_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -19442,6 +19709,7 @@ struct SM90_64x160x32_S32S8U8_RS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + 
cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -19485,7 +19753,7 @@ struct SM90_64x160x32_S32S8U8_RS_TN "l"(desc_b), "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x160x32_S32S8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x160x32_S32S8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -19495,7 +19763,7 @@ struct SM90_64x160x32_S32S8U8_RS_TN #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) // GMMA 64x160x32 TN S32+=S8*U8 -struct SM90_64x160x32_S32S8U8_RS_TN_SATURATE +struct MMA_64x160x32_S32S8U8_RS_TN_SATURATE { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -19528,6 +19796,7 @@ struct SM90_64x160x32_S32S8U8_RS_TN_SATURATE GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -19571,7 +19840,7 @@ struct SM90_64x160x32_S32S8U8_RS_TN_SATURATE "l"(desc_b), "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x160x32_S32S8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x160x32_S32S8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -19581,7 +19850,7 @@ struct SM90_64x160x32_S32S8U8_RS_TN_SATURATE #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) // GMMA 64x176x32 TN S32+=S8*U8 -struct SM90_64x176x32_S32S8U8_RS_TN +struct MMA_64x176x32_S32S8U8_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -19616,6 +19885,7 @@ struct SM90_64x176x32_S32S8U8_RS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -19662,7 +19932,7 @@ struct SM90_64x176x32_S32S8U8_RS_TN "l"(desc_b), "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x176x32_S32S8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x176x32_S32S8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -19672,7 +19942,7 @@ struct SM90_64x176x32_S32S8U8_RS_TN #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) // GMMA 64x176x32 TN S32+=S8*U8 -struct SM90_64x176x32_S32S8U8_RS_TN_SATURATE +struct MMA_64x176x32_S32S8U8_RS_TN_SATURATE { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -19707,6 +19977,7 @@ struct SM90_64x176x32_S32S8U8_RS_TN_SATURATE GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -19753,7 +20024,7 @@ struct SM90_64x176x32_S32S8U8_RS_TN_SATURATE "l"(desc_b), "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x176x32_S32S8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x176x32_S32S8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -19762,7 +20033,7 @@ struct SM90_64x176x32_S32S8U8_RS_TN_SATURATE //////////////////////////////////////////////////////////////////////////////////////////////////// // GMMA 64x192x32 TN S32+=S8*U8 -struct SM90_64x192x32_S32S8U8_RS_TN +struct MMA_64x192x32_S32S8U8_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -19799,6 +20070,7 @@ struct 
SM90_64x192x32_S32S8U8_RS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -19848,7 +20120,7 @@ struct SM90_64x192x32_S32S8U8_RS_TN "l"(desc_b), "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x192x32_S32S8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x192x32_S32S8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -19856,7 +20128,7 @@ struct SM90_64x192x32_S32S8U8_RS_TN //////////////////////////////////////////////////////////////////////////////////////////////////// // GMMA 64x192x32 TN S32+=S8*U8 -struct SM90_64x192x32_S32S8U8_RS_TN_SATURATE +struct MMA_64x192x32_S32S8U8_RS_TN_SATURATE { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -19893,6 +20165,7 @@ struct SM90_64x192x32_S32S8U8_RS_TN_SATURATE GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -19942,7 +20215,7 @@ struct SM90_64x192x32_S32S8U8_RS_TN_SATURATE "l"(desc_b), "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x192x32_S32S8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x192x32_S32S8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -19951,7 +20224,7 @@ struct SM90_64x192x32_S32S8U8_RS_TN_SATURATE #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) // GMMA 64x208x32 TN S32+=S8*U8 -struct SM90_64x208x32_S32S8U8_RS_TN +struct MMA_64x208x32_S32S8U8_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -19990,6 +20263,7 @@ struct SM90_64x208x32_S32S8U8_RS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -20042,7 +20316,7 @@ struct SM90_64x208x32_S32S8U8_RS_TN "l"(desc_b), "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x208x32_S32S8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x208x32_S32S8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -20052,7 +20326,7 @@ struct SM90_64x208x32_S32S8U8_RS_TN #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) // GMMA 64x208x32 TN S32+=S8*U8 -struct SM90_64x208x32_S32S8U8_RS_TN_SATURATE +struct MMA_64x208x32_S32S8U8_RS_TN_SATURATE { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -20091,6 +20365,7 @@ struct SM90_64x208x32_S32S8U8_RS_TN_SATURATE GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -20143,7 +20418,7 @@ struct SM90_64x208x32_S32S8U8_RS_TN_SATURATE "l"(desc_b), "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x208x32_S32S8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x208x32_S32S8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -20153,7 +20428,7 @@ struct SM90_64x208x32_S32S8U8_RS_TN_SATURATE #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) // GMMA 64x224x32 TN S32+=S8*U8 -struct SM90_64x224x32_S32S8U8_RS_TN 
+struct MMA_64x224x32_S32S8U8_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -20194,6 +20469,7 @@ struct SM90_64x224x32_S32S8U8_RS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -20249,7 +20525,7 @@ struct SM90_64x224x32_S32S8U8_RS_TN "l"(desc_b), "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x224x32_S32S8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x224x32_S32S8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -20259,7 +20535,7 @@ struct SM90_64x224x32_S32S8U8_RS_TN #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) // GMMA 64x224x32 TN S32+=S8*U8 -struct SM90_64x224x32_S32S8U8_RS_TN_SATURATE +struct MMA_64x224x32_S32S8U8_RS_TN_SATURATE { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -20300,6 +20576,7 @@ struct SM90_64x224x32_S32S8U8_RS_TN_SATURATE GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -20355,7 +20632,7 @@ struct SM90_64x224x32_S32S8U8_RS_TN_SATURATE "l"(desc_b), "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x224x32_S32S8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x224x32_S32S8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -20365,7 +20642,7 @@ struct SM90_64x224x32_S32S8U8_RS_TN_SATURATE #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) // GMMA 64x240x32 TN S32+=S8*U8 -struct SM90_64x240x32_S32S8U8_RS_TN +struct MMA_64x240x32_S32S8U8_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -20408,6 +20685,7 @@ struct SM90_64x240x32_S32S8U8_RS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -20466,7 +20744,7 @@ struct SM90_64x240x32_S32S8U8_RS_TN "l"(desc_b), "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x240x32_S32S8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x240x32_S32S8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -20476,7 +20754,7 @@ struct SM90_64x240x32_S32S8U8_RS_TN #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) // GMMA 64x240x32 TN S32+=S8*U8 -struct SM90_64x240x32_S32S8U8_RS_TN_SATURATE +struct MMA_64x240x32_S32S8U8_RS_TN_SATURATE { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -20519,6 +20797,7 @@ struct SM90_64x240x32_S32S8U8_RS_TN_SATURATE GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -20577,7 +20856,7 @@ struct SM90_64x240x32_S32S8U8_RS_TN_SATURATE "l"(desc_b), "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x240x32_S32S8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x240x32_S32S8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -20586,7 +20865,7 @@ struct SM90_64x240x32_S32S8U8_RS_TN_SATURATE 
//////////////////////////////////////////////////////////////////////////////////////////////////// // GMMA 64x256x32 TN S32+=S8*U8 -struct SM90_64x256x32_S32S8U8_RS_TN +struct MMA_64x256x32_S32S8U8_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -20631,6 +20910,7 @@ struct SM90_64x256x32_S32S8U8_RS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -20692,7 +20972,7 @@ struct SM90_64x256x32_S32S8U8_RS_TN "l"(desc_b), "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x256x32_S32S8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x256x32_S32S8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -20700,7 +20980,7 @@ struct SM90_64x256x32_S32S8U8_RS_TN //////////////////////////////////////////////////////////////////////////////////////////////////// // GMMA 64x256x32 TN S32+=S8*U8 -struct SM90_64x256x32_S32S8U8_RS_TN_SATURATE +struct MMA_64x256x32_S32S8U8_RS_TN_SATURATE { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -20745,6 +21025,7 @@ struct SM90_64x256x32_S32S8U8_RS_TN_SATURATE GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -20806,7 +21087,7 @@ struct SM90_64x256x32_S32S8U8_RS_TN_SATURATE "l"(desc_b), "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x256x32_S32S8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x256x32_S32S8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -20814,7 +21095,7 @@ struct SM90_64x256x32_S32S8U8_RS_TN_SATURATE //////////////////////////////////////////////////////////////////////////////////////////////////// // GMMA 64x8x32 TN S32+=U8*S8 -struct SM90_64x8x32_S32U8S8_SS_TN +struct MMA_64x8x32_S32U8S8_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -20828,6 +21109,7 @@ struct SM90_64x8x32_S32U8S8_SS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -20843,7 +21125,7 @@ struct SM90_64x8x32_S32U8S8_SS_TN "l"(desc_b), "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x8x32_S32U8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x8x32_S32U8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -20851,7 +21133,7 @@ struct SM90_64x8x32_S32U8S8_SS_TN //////////////////////////////////////////////////////////////////////////////////////////////////// // GMMA 64x8x32 TN S32+=U8*S8 -struct SM90_64x8x32_S32U8S8_SS_TN_SATURATE +struct MMA_64x8x32_S32U8S8_SS_TN_SATURATE { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -20865,6 +21147,7 @@ struct SM90_64x8x32_S32U8S8_SS_TN_SATURATE GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -20880,7 +21163,7 @@ struct SM90_64x8x32_S32U8S8_SS_TN_SATURATE "l"(desc_b), "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use 
SM90_64x8x32_S32U8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x8x32_S32U8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -20888,7 +21171,7 @@ struct SM90_64x8x32_S32U8S8_SS_TN_SATURATE //////////////////////////////////////////////////////////////////////////////////////////////////// // GMMA 64x16x32 TN S32+=U8*S8 -struct SM90_64x16x32_S32U8S8_SS_TN +struct MMA_64x16x32_S32U8S8_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -20903,6 +21186,7 @@ struct SM90_64x16x32_S32U8S8_SS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -20919,7 +21203,7 @@ struct SM90_64x16x32_S32U8S8_SS_TN "l"(desc_b), "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x16x32_S32U8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x16x32_S32U8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -20927,7 +21211,7 @@ struct SM90_64x16x32_S32U8S8_SS_TN //////////////////////////////////////////////////////////////////////////////////////////////////// // GMMA 64x16x32 TN S32+=U8*S8 -struct SM90_64x16x32_S32U8S8_SS_TN_SATURATE +struct MMA_64x16x32_S32U8S8_SS_TN_SATURATE { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -20942,6 +21226,7 @@ struct SM90_64x16x32_S32U8S8_SS_TN_SATURATE GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -20958,7 +21243,7 @@ struct SM90_64x16x32_S32U8S8_SS_TN_SATURATE "l"(desc_b), "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x16x32_S32U8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x16x32_S32U8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -20966,7 +21251,7 @@ struct SM90_64x16x32_S32U8S8_SS_TN_SATURATE //////////////////////////////////////////////////////////////////////////////////////////////////// // GMMA 64x32x32 TN S32+=U8*S8 -struct SM90_64x32x32_S32U8S8_SS_TN +struct MMA_64x32x32_S32U8S8_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -20983,6 +21268,7 @@ struct SM90_64x32x32_S32U8S8_SS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -21002,7 +21288,7 @@ struct SM90_64x32x32_S32U8S8_SS_TN "l"(desc_b), "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x32x32_S32U8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x32x32_S32U8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -21010,7 +21296,7 @@ struct SM90_64x32x32_S32U8S8_SS_TN //////////////////////////////////////////////////////////////////////////////////////////////////// // GMMA 64x32x32 TN S32+=U8*S8 -struct SM90_64x32x32_S32U8S8_SS_TN_SATURATE +struct MMA_64x32x32_S32U8S8_SS_TN_SATURATE { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -21027,6 +21313,7 @@ struct SM90_64x32x32_S32U8S8_SS_TN_SATURATE GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if 
defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -21046,7 +21333,7 @@ struct SM90_64x32x32_S32U8S8_SS_TN_SATURATE "l"(desc_b), "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x32x32_S32U8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x32x32_S32U8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -21055,7 +21342,7 @@ struct SM90_64x32x32_S32U8S8_SS_TN_SATURATE #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) // GMMA 64x48x32 TN S32+=U8*S8 -struct SM90_64x48x32_S32U8S8_SS_TN +struct MMA_64x48x32_S32U8S8_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -21074,6 +21361,7 @@ struct SM90_64x48x32_S32U8S8_SS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -21096,7 +21384,7 @@ struct SM90_64x48x32_S32U8S8_SS_TN "l"(desc_b), "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x48x32_S32U8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x48x32_S32U8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -21106,7 +21394,7 @@ struct SM90_64x48x32_S32U8S8_SS_TN #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) // GMMA 64x48x32 TN S32+=U8*S8 -struct SM90_64x48x32_S32U8S8_SS_TN_SATURATE +struct MMA_64x48x32_S32U8S8_SS_TN_SATURATE { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -21125,6 +21413,7 @@ struct SM90_64x48x32_S32U8S8_SS_TN_SATURATE GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -21147,7 +21436,7 @@ struct SM90_64x48x32_S32U8S8_SS_TN_SATURATE "l"(desc_b), "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x48x32_S32U8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x48x32_S32U8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -21156,7 +21445,7 @@ struct SM90_64x48x32_S32U8S8_SS_TN_SATURATE //////////////////////////////////////////////////////////////////////////////////////////////////// // GMMA 64x64x32 TN S32+=U8*S8 -struct SM90_64x64x32_S32U8S8_SS_TN +struct MMA_64x64x32_S32U8S8_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -21177,6 +21466,7 @@ struct SM90_64x64x32_S32U8S8_SS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -21202,7 +21492,7 @@ struct SM90_64x64x32_S32U8S8_SS_TN "l"(desc_b), "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x64x32_S32U8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x64x32_S32U8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -21210,7 +21500,7 @@ struct SM90_64x64x32_S32U8S8_SS_TN //////////////////////////////////////////////////////////////////////////////////////////////////// // GMMA 64x64x32 TN S32+=U8*S8 -struct SM90_64x64x32_S32U8S8_SS_TN_SATURATE +struct MMA_64x64x32_S32U8S8_SS_TN_SATURATE { 
using DRegisters = void; using ARegisters = uint64_t[1]; @@ -21231,6 +21521,7 @@ struct SM90_64x64x32_S32U8S8_SS_TN_SATURATE GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -21256,7 +21547,7 @@ struct SM90_64x64x32_S32U8S8_SS_TN_SATURATE "l"(desc_b), "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x64x32_S32U8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x64x32_S32U8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -21265,7 +21556,7 @@ struct SM90_64x64x32_S32U8S8_SS_TN_SATURATE #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) // GMMA 64x80x32 TN S32+=U8*S8 -struct SM90_64x80x32_S32U8S8_SS_TN +struct MMA_64x80x32_S32U8S8_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -21288,6 +21579,7 @@ struct SM90_64x80x32_S32U8S8_SS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -21316,7 +21608,7 @@ struct SM90_64x80x32_S32U8S8_SS_TN "l"(desc_b), "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x80x32_S32U8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x80x32_S32U8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -21326,7 +21618,7 @@ struct SM90_64x80x32_S32U8S8_SS_TN #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) // GMMA 64x80x32 TN S32+=U8*S8 -struct SM90_64x80x32_S32U8S8_SS_TN_SATURATE +struct MMA_64x80x32_S32U8S8_SS_TN_SATURATE { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -21349,6 +21641,7 @@ struct SM90_64x80x32_S32U8S8_SS_TN_SATURATE GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -21377,7 +21670,7 @@ struct SM90_64x80x32_S32U8S8_SS_TN_SATURATE "l"(desc_b), "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x80x32_S32U8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x80x32_S32U8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -21386,7 +21679,7 @@ struct SM90_64x80x32_S32U8S8_SS_TN_SATURATE //////////////////////////////////////////////////////////////////////////////////////////////////// // GMMA 64x96x32 TN S32+=U8*S8 -struct SM90_64x96x32_S32U8S8_SS_TN +struct MMA_64x96x32_S32U8S8_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -21411,6 +21704,7 @@ struct SM90_64x96x32_S32U8S8_SS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -21442,7 +21736,7 @@ struct SM90_64x96x32_S32U8S8_SS_TN "l"(desc_b), "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x96x32_S32U8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x96x32_S32U8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -21450,7 +21744,7 @@ struct SM90_64x96x32_S32U8S8_SS_TN 
//////////////////////////////////////////////////////////////////////////////////////////////////// // GMMA 64x96x32 TN S32+=U8*S8 -struct SM90_64x96x32_S32U8S8_SS_TN_SATURATE +struct MMA_64x96x32_S32U8S8_SS_TN_SATURATE { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -21475,6 +21769,7 @@ struct SM90_64x96x32_S32U8S8_SS_TN_SATURATE GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -21506,7 +21801,7 @@ struct SM90_64x96x32_S32U8S8_SS_TN_SATURATE "l"(desc_b), "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x96x32_S32U8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x96x32_S32U8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -21515,7 +21810,7 @@ struct SM90_64x96x32_S32U8S8_SS_TN_SATURATE #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) // GMMA 64x112x32 TN S32+=U8*S8 -struct SM90_64x112x32_S32U8S8_SS_TN +struct MMA_64x112x32_S32U8S8_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -21542,6 +21837,7 @@ struct SM90_64x112x32_S32U8S8_SS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -21576,7 +21872,7 @@ struct SM90_64x112x32_S32U8S8_SS_TN "l"(desc_b), "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x112x32_S32U8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x112x32_S32U8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -21586,7 +21882,7 @@ struct SM90_64x112x32_S32U8S8_SS_TN #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) // GMMA 64x112x32 TN S32+=U8*S8 -struct SM90_64x112x32_S32U8S8_SS_TN_SATURATE +struct MMA_64x112x32_S32U8S8_SS_TN_SATURATE { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -21613,6 +21909,7 @@ struct SM90_64x112x32_S32U8S8_SS_TN_SATURATE GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -21647,7 +21944,7 @@ struct SM90_64x112x32_S32U8S8_SS_TN_SATURATE "l"(desc_b), "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x112x32_S32U8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x112x32_S32U8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -21656,7 +21953,7 @@ struct SM90_64x112x32_S32U8S8_SS_TN_SATURATE //////////////////////////////////////////////////////////////////////////////////////////////////// // GMMA 64x128x32 TN S32+=U8*S8 -struct SM90_64x128x32_S32U8S8_SS_TN +struct MMA_64x128x32_S32U8S8_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -21685,6 +21982,7 @@ struct SM90_64x128x32_S32U8S8_SS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -21722,7 +22020,7 @@ struct SM90_64x128x32_S32U8S8_SS_TN "l"(desc_b), "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x128x32_S32U8S8_SS_TN without 
CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x128x32_S32U8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -21730,7 +22028,7 @@ struct SM90_64x128x32_S32U8S8_SS_TN //////////////////////////////////////////////////////////////////////////////////////////////////// // GMMA 64x128x32 TN S32+=U8*S8 -struct SM90_64x128x32_S32U8S8_SS_TN_SATURATE +struct MMA_64x128x32_S32U8S8_SS_TN_SATURATE { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -21759,6 +22057,7 @@ struct SM90_64x128x32_S32U8S8_SS_TN_SATURATE GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -21796,7 +22095,7 @@ struct SM90_64x128x32_S32U8S8_SS_TN_SATURATE "l"(desc_b), "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x128x32_S32U8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x128x32_S32U8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -21805,7 +22104,7 @@ struct SM90_64x128x32_S32U8S8_SS_TN_SATURATE #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) // GMMA 64x144x32 TN S32+=U8*S8 -struct SM90_64x144x32_S32U8S8_SS_TN +struct MMA_64x144x32_S32U8S8_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -21836,6 +22135,7 @@ struct SM90_64x144x32_S32U8S8_SS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -21876,7 +22176,7 @@ struct SM90_64x144x32_S32U8S8_SS_TN "l"(desc_b), "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x144x32_S32U8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x144x32_S32U8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -21886,7 +22186,7 @@ struct SM90_64x144x32_S32U8S8_SS_TN #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) // GMMA 64x144x32 TN S32+=U8*S8 -struct SM90_64x144x32_S32U8S8_SS_TN_SATURATE +struct MMA_64x144x32_S32U8S8_SS_TN_SATURATE { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -21917,6 +22217,7 @@ struct SM90_64x144x32_S32U8S8_SS_TN_SATURATE GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -21957,7 +22258,7 @@ struct SM90_64x144x32_S32U8S8_SS_TN_SATURATE "l"(desc_b), "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x144x32_S32U8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x144x32_S32U8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -21967,7 +22268,7 @@ struct SM90_64x144x32_S32U8S8_SS_TN_SATURATE #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) // GMMA 64x160x32 TN S32+=U8*S8 -struct SM90_64x160x32_S32U8S8_SS_TN +struct MMA_64x160x32_S32U8S8_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -22000,6 +22301,7 @@ struct SM90_64x160x32_S32U8S8_SS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -22043,7 +22345,7 @@ 
struct SM90_64x160x32_S32U8S8_SS_TN "l"(desc_b), "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x160x32_S32U8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x160x32_S32U8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -22053,7 +22355,7 @@ struct SM90_64x160x32_S32U8S8_SS_TN #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) // GMMA 64x160x32 TN S32+=U8*S8 -struct SM90_64x160x32_S32U8S8_SS_TN_SATURATE +struct MMA_64x160x32_S32U8S8_SS_TN_SATURATE { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -22086,6 +22388,7 @@ struct SM90_64x160x32_S32U8S8_SS_TN_SATURATE GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -22129,7 +22432,7 @@ struct SM90_64x160x32_S32U8S8_SS_TN_SATURATE "l"(desc_b), "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x160x32_S32U8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x160x32_S32U8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -22139,7 +22442,7 @@ struct SM90_64x160x32_S32U8S8_SS_TN_SATURATE #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) // GMMA 64x176x32 TN S32+=U8*S8 -struct SM90_64x176x32_S32U8S8_SS_TN +struct MMA_64x176x32_S32U8S8_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -22174,6 +22477,7 @@ struct SM90_64x176x32_S32U8S8_SS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -22220,7 +22524,7 @@ struct SM90_64x176x32_S32U8S8_SS_TN "l"(desc_b), "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x176x32_S32U8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x176x32_S32U8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -22230,7 +22534,7 @@ struct SM90_64x176x32_S32U8S8_SS_TN #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) // GMMA 64x176x32 TN S32+=U8*S8 -struct SM90_64x176x32_S32U8S8_SS_TN_SATURATE +struct MMA_64x176x32_S32U8S8_SS_TN_SATURATE { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -22265,6 +22569,7 @@ struct SM90_64x176x32_S32U8S8_SS_TN_SATURATE GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -22311,7 +22616,7 @@ struct SM90_64x176x32_S32U8S8_SS_TN_SATURATE "l"(desc_b), "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x176x32_S32U8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x176x32_S32U8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -22320,7 +22625,7 @@ struct SM90_64x176x32_S32U8S8_SS_TN_SATURATE //////////////////////////////////////////////////////////////////////////////////////////////////// // GMMA 64x192x32 TN S32+=U8*S8 -struct SM90_64x192x32_S32U8S8_SS_TN +struct MMA_64x192x32_S32U8S8_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -22357,6 +22662,7 @@ struct SM90_64x192x32_S32U8S8_SS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if 
defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -22406,7 +22712,7 @@ struct SM90_64x192x32_S32U8S8_SS_TN "l"(desc_b), "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x192x32_S32U8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x192x32_S32U8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -22414,7 +22720,7 @@ struct SM90_64x192x32_S32U8S8_SS_TN //////////////////////////////////////////////////////////////////////////////////////////////////// // GMMA 64x192x32 TN S32+=U8*S8 -struct SM90_64x192x32_S32U8S8_SS_TN_SATURATE +struct MMA_64x192x32_S32U8S8_SS_TN_SATURATE { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -22451,6 +22757,7 @@ struct SM90_64x192x32_S32U8S8_SS_TN_SATURATE GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -22500,7 +22807,7 @@ struct SM90_64x192x32_S32U8S8_SS_TN_SATURATE "l"(desc_b), "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x192x32_S32U8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x192x32_S32U8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -22509,7 +22816,7 @@ struct SM90_64x192x32_S32U8S8_SS_TN_SATURATE #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) // GMMA 64x208x32 TN S32+=U8*S8 -struct SM90_64x208x32_S32U8S8_SS_TN +struct MMA_64x208x32_S32U8S8_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -22548,6 +22855,7 @@ struct SM90_64x208x32_S32U8S8_SS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -22600,7 +22908,7 @@ struct SM90_64x208x32_S32U8S8_SS_TN "l"(desc_b), "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x208x32_S32U8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x208x32_S32U8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -22610,7 +22918,7 @@ struct SM90_64x208x32_S32U8S8_SS_TN #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) // GMMA 64x208x32 TN S32+=U8*S8 -struct SM90_64x208x32_S32U8S8_SS_TN_SATURATE +struct MMA_64x208x32_S32U8S8_SS_TN_SATURATE { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -22649,6 +22957,7 @@ struct SM90_64x208x32_S32U8S8_SS_TN_SATURATE GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -22701,7 +23010,7 @@ struct SM90_64x208x32_S32U8S8_SS_TN_SATURATE "l"(desc_b), "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x208x32_S32U8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x208x32_S32U8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -22711,7 +23020,7 @@ struct SM90_64x208x32_S32U8S8_SS_TN_SATURATE #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) // GMMA 64x224x32 TN S32+=U8*S8 -struct SM90_64x224x32_S32U8S8_SS_TN +struct MMA_64x224x32_S32U8S8_SS_TN { using 
DRegisters = void; using ARegisters = uint64_t[1]; @@ -22752,6 +23061,7 @@ struct SM90_64x224x32_S32U8S8_SS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -22807,7 +23117,7 @@ struct SM90_64x224x32_S32U8S8_SS_TN "l"(desc_b), "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x224x32_S32U8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x224x32_S32U8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -22817,7 +23127,7 @@ struct SM90_64x224x32_S32U8S8_SS_TN #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) // GMMA 64x224x32 TN S32+=U8*S8 -struct SM90_64x224x32_S32U8S8_SS_TN_SATURATE +struct MMA_64x224x32_S32U8S8_SS_TN_SATURATE { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -22858,6 +23168,7 @@ struct SM90_64x224x32_S32U8S8_SS_TN_SATURATE GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -22913,7 +23224,7 @@ struct SM90_64x224x32_S32U8S8_SS_TN_SATURATE "l"(desc_b), "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x224x32_S32U8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x224x32_S32U8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -22923,7 +23234,7 @@ struct SM90_64x224x32_S32U8S8_SS_TN_SATURATE #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) // GMMA 64x240x32 TN S32+=U8*S8 -struct SM90_64x240x32_S32U8S8_SS_TN +struct MMA_64x240x32_S32U8S8_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -22966,6 +23277,7 @@ struct SM90_64x240x32_S32U8S8_SS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -23024,7 +23336,7 @@ struct SM90_64x240x32_S32U8S8_SS_TN "l"(desc_b), "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x240x32_S32U8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x240x32_S32U8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -23034,7 +23346,7 @@ struct SM90_64x240x32_S32U8S8_SS_TN #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) // GMMA 64x240x32 TN S32+=U8*S8 -struct SM90_64x240x32_S32U8S8_SS_TN_SATURATE +struct MMA_64x240x32_S32U8S8_SS_TN_SATURATE { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -23077,6 +23389,7 @@ struct SM90_64x240x32_S32U8S8_SS_TN_SATURATE GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -23135,7 +23448,7 @@ struct SM90_64x240x32_S32U8S8_SS_TN_SATURATE "l"(desc_b), "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x240x32_S32U8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x240x32_S32U8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -23144,7 +23457,7 @@ struct SM90_64x240x32_S32U8S8_SS_TN_SATURATE 
//////////////////////////////////////////////////////////////////////////////////////////////////// // GMMA 64x256x32 TN S32+=U8*S8 -struct SM90_64x256x32_S32U8S8_SS_TN +struct MMA_64x256x32_S32U8S8_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -23189,6 +23502,7 @@ struct SM90_64x256x32_S32U8S8_SS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -23250,7 +23564,7 @@ struct SM90_64x256x32_S32U8S8_SS_TN "l"(desc_b), "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x256x32_S32U8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x256x32_S32U8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -23258,7 +23572,7 @@ struct SM90_64x256x32_S32U8S8_SS_TN //////////////////////////////////////////////////////////////////////////////////////////////////// // GMMA 64x256x32 TN S32+=U8*S8 -struct SM90_64x256x32_S32U8S8_SS_TN_SATURATE +struct MMA_64x256x32_S32U8S8_SS_TN_SATURATE { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -23303,6 +23617,7 @@ struct SM90_64x256x32_S32U8S8_SS_TN_SATURATE GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -23364,7 +23679,7 @@ struct SM90_64x256x32_S32U8S8_SS_TN_SATURATE "l"(desc_b), "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x256x32_S32U8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x256x32_S32U8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -23372,7 +23687,7 @@ struct SM90_64x256x32_S32U8S8_SS_TN_SATURATE //////////////////////////////////////////////////////////////////////////////////////////////////// // GMMA 64x8x32 TN S32+=U8*S8 -struct SM90_64x8x32_S32U8S8_RS_TN +struct MMA_64x8x32_S32U8S8_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -23386,6 +23701,7 @@ struct SM90_64x8x32_S32U8S8_RS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -23401,7 +23717,7 @@ struct SM90_64x8x32_S32U8S8_RS_TN "l"(desc_b), "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x8x32_S32U8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x8x32_S32U8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -23409,7 +23725,7 @@ struct SM90_64x8x32_S32U8S8_RS_TN //////////////////////////////////////////////////////////////////////////////////////////////////// // GMMA 64x8x32 TN S32+=U8*S8 -struct SM90_64x8x32_S32U8S8_RS_TN_SATURATE +struct MMA_64x8x32_S32U8S8_RS_TN_SATURATE { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -23423,6 +23739,7 @@ struct SM90_64x8x32_S32U8S8_RS_TN_SATURATE GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -23438,7 +23755,7 @@ struct SM90_64x8x32_S32U8S8_RS_TN_SATURATE "l"(desc_b), "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use 
SM90_64x8x32_S32U8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x8x32_S32U8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -23446,7 +23763,7 @@ struct SM90_64x8x32_S32U8S8_RS_TN_SATURATE //////////////////////////////////////////////////////////////////////////////////////////////////// // GMMA 64x16x32 TN S32+=U8*S8 -struct SM90_64x16x32_S32U8S8_RS_TN +struct MMA_64x16x32_S32U8S8_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -23461,6 +23778,7 @@ struct SM90_64x16x32_S32U8S8_RS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -23477,7 +23795,7 @@ struct SM90_64x16x32_S32U8S8_RS_TN "l"(desc_b), "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x16x32_S32U8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x16x32_S32U8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -23485,7 +23803,7 @@ struct SM90_64x16x32_S32U8S8_RS_TN //////////////////////////////////////////////////////////////////////////////////////////////////// // GMMA 64x16x32 TN S32+=U8*S8 -struct SM90_64x16x32_S32U8S8_RS_TN_SATURATE +struct MMA_64x16x32_S32U8S8_RS_TN_SATURATE { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -23500,6 +23818,7 @@ struct SM90_64x16x32_S32U8S8_RS_TN_SATURATE GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -23516,7 +23835,7 @@ struct SM90_64x16x32_S32U8S8_RS_TN_SATURATE "l"(desc_b), "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x16x32_S32U8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x16x32_S32U8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -23524,7 +23843,7 @@ struct SM90_64x16x32_S32U8S8_RS_TN_SATURATE //////////////////////////////////////////////////////////////////////////////////////////////////// // GMMA 64x32x32 TN S32+=U8*S8 -struct SM90_64x32x32_S32U8S8_RS_TN +struct MMA_64x32x32_S32U8S8_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -23541,6 +23860,7 @@ struct SM90_64x32x32_S32U8S8_RS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -23560,7 +23880,7 @@ struct SM90_64x32x32_S32U8S8_RS_TN "l"(desc_b), "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x32x32_S32U8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x32x32_S32U8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -23568,7 +23888,7 @@ struct SM90_64x32x32_S32U8S8_RS_TN //////////////////////////////////////////////////////////////////////////////////////////////////// // GMMA 64x32x32 TN S32+=U8*S8 -struct SM90_64x32x32_S32U8S8_RS_TN_SATURATE +struct MMA_64x32x32_S32U8S8_RS_TN_SATURATE { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -23585,6 +23905,7 @@ struct SM90_64x32x32_S32U8S8_RS_TN_SATURATE GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + 
cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -23604,7 +23925,7 @@ struct SM90_64x32x32_S32U8S8_RS_TN_SATURATE "l"(desc_b), "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x32x32_S32U8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x32x32_S32U8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -23613,7 +23934,7 @@ struct SM90_64x32x32_S32U8S8_RS_TN_SATURATE #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) // GMMA 64x48x32 TN S32+=U8*S8 -struct SM90_64x48x32_S32U8S8_RS_TN +struct MMA_64x48x32_S32U8S8_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -23632,6 +23953,7 @@ struct SM90_64x48x32_S32U8S8_RS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -23654,7 +23976,7 @@ struct SM90_64x48x32_S32U8S8_RS_TN "l"(desc_b), "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x48x32_S32U8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x48x32_S32U8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -23664,7 +23986,7 @@ struct SM90_64x48x32_S32U8S8_RS_TN #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) // GMMA 64x48x32 TN S32+=U8*S8 -struct SM90_64x48x32_S32U8S8_RS_TN_SATURATE +struct MMA_64x48x32_S32U8S8_RS_TN_SATURATE { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -23683,6 +24005,7 @@ struct SM90_64x48x32_S32U8S8_RS_TN_SATURATE GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -23705,7 +24028,7 @@ struct SM90_64x48x32_S32U8S8_RS_TN_SATURATE "l"(desc_b), "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x48x32_S32U8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x48x32_S32U8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -23714,7 +24037,7 @@ struct SM90_64x48x32_S32U8S8_RS_TN_SATURATE //////////////////////////////////////////////////////////////////////////////////////////////////// // GMMA 64x64x32 TN S32+=U8*S8 -struct SM90_64x64x32_S32U8S8_RS_TN +struct MMA_64x64x32_S32U8S8_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -23735,6 +24058,7 @@ struct SM90_64x64x32_S32U8S8_RS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -23760,7 +24084,7 @@ struct SM90_64x64x32_S32U8S8_RS_TN "l"(desc_b), "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x64x32_S32U8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x64x32_S32U8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -23768,7 +24092,7 @@ struct SM90_64x64x32_S32U8S8_RS_TN //////////////////////////////////////////////////////////////////////////////////////////////////// // GMMA 64x64x32 TN S32+=U8*S8 -struct SM90_64x64x32_S32U8S8_RS_TN_SATURATE +struct MMA_64x64x32_S32U8S8_RS_TN_SATURATE { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -23789,6 +24113,7 
@@ struct SM90_64x64x32_S32U8S8_RS_TN_SATURATE GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -23814,7 +24139,7 @@ struct SM90_64x64x32_S32U8S8_RS_TN_SATURATE "l"(desc_b), "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x64x32_S32U8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x64x32_S32U8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -23823,7 +24148,7 @@ struct SM90_64x64x32_S32U8S8_RS_TN_SATURATE #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) // GMMA 64x80x32 TN S32+=U8*S8 -struct SM90_64x80x32_S32U8S8_RS_TN +struct MMA_64x80x32_S32U8S8_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -23846,6 +24171,7 @@ struct SM90_64x80x32_S32U8S8_RS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -23874,7 +24200,7 @@ struct SM90_64x80x32_S32U8S8_RS_TN "l"(desc_b), "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x80x32_S32U8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x80x32_S32U8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -23884,7 +24210,7 @@ struct SM90_64x80x32_S32U8S8_RS_TN #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) // GMMA 64x80x32 TN S32+=U8*S8 -struct SM90_64x80x32_S32U8S8_RS_TN_SATURATE +struct MMA_64x80x32_S32U8S8_RS_TN_SATURATE { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -23907,6 +24233,7 @@ struct SM90_64x80x32_S32U8S8_RS_TN_SATURATE GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -23935,7 +24262,7 @@ struct SM90_64x80x32_S32U8S8_RS_TN_SATURATE "l"(desc_b), "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x80x32_S32U8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x80x32_S32U8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -23944,7 +24271,7 @@ struct SM90_64x80x32_S32U8S8_RS_TN_SATURATE //////////////////////////////////////////////////////////////////////////////////////////////////// // GMMA 64x96x32 TN S32+=U8*S8 -struct SM90_64x96x32_S32U8S8_RS_TN +struct MMA_64x96x32_S32U8S8_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -23969,6 +24296,7 @@ struct SM90_64x96x32_S32U8S8_RS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -24000,7 +24328,7 @@ struct SM90_64x96x32_S32U8S8_RS_TN "l"(desc_b), "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x96x32_S32U8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x96x32_S32U8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -24008,7 +24336,7 @@ struct SM90_64x96x32_S32U8S8_RS_TN //////////////////////////////////////////////////////////////////////////////////////////////////// // GMMA 64x96x32 TN S32+=U8*S8 -struct 
SM90_64x96x32_S32U8S8_RS_TN_SATURATE +struct MMA_64x96x32_S32U8S8_RS_TN_SATURATE { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -24033,6 +24361,7 @@ struct SM90_64x96x32_S32U8S8_RS_TN_SATURATE GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -24064,7 +24393,7 @@ struct SM90_64x96x32_S32U8S8_RS_TN_SATURATE "l"(desc_b), "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x96x32_S32U8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x96x32_S32U8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -24073,7 +24402,7 @@ struct SM90_64x96x32_S32U8S8_RS_TN_SATURATE #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) // GMMA 64x112x32 TN S32+=U8*S8 -struct SM90_64x112x32_S32U8S8_RS_TN +struct MMA_64x112x32_S32U8S8_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -24100,6 +24429,7 @@ struct SM90_64x112x32_S32U8S8_RS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -24134,7 +24464,7 @@ struct SM90_64x112x32_S32U8S8_RS_TN "l"(desc_b), "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x112x32_S32U8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x112x32_S32U8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -24144,7 +24474,7 @@ struct SM90_64x112x32_S32U8S8_RS_TN #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) // GMMA 64x112x32 TN S32+=U8*S8 -struct SM90_64x112x32_S32U8S8_RS_TN_SATURATE +struct MMA_64x112x32_S32U8S8_RS_TN_SATURATE { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -24171,6 +24501,7 @@ struct SM90_64x112x32_S32U8S8_RS_TN_SATURATE GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -24205,7 +24536,7 @@ struct SM90_64x112x32_S32U8S8_RS_TN_SATURATE "l"(desc_b), "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x112x32_S32U8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x112x32_S32U8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -24214,7 +24545,7 @@ struct SM90_64x112x32_S32U8S8_RS_TN_SATURATE //////////////////////////////////////////////////////////////////////////////////////////////////// // GMMA 64x128x32 TN S32+=U8*S8 -struct SM90_64x128x32_S32U8S8_RS_TN +struct MMA_64x128x32_S32U8S8_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -24243,6 +24574,7 @@ struct SM90_64x128x32_S32U8S8_RS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -24280,7 +24612,7 @@ struct SM90_64x128x32_S32U8S8_RS_TN "l"(desc_b), "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x128x32_S32U8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x128x32_S32U8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -24288,7 +24620,7 @@ struct 
SM90_64x128x32_S32U8S8_RS_TN //////////////////////////////////////////////////////////////////////////////////////////////////// // GMMA 64x128x32 TN S32+=U8*S8 -struct SM90_64x128x32_S32U8S8_RS_TN_SATURATE +struct MMA_64x128x32_S32U8S8_RS_TN_SATURATE { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -24317,6 +24649,7 @@ struct SM90_64x128x32_S32U8S8_RS_TN_SATURATE GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -24354,7 +24687,7 @@ struct SM90_64x128x32_S32U8S8_RS_TN_SATURATE "l"(desc_b), "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x128x32_S32U8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x128x32_S32U8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -24363,7 +24696,7 @@ struct SM90_64x128x32_S32U8S8_RS_TN_SATURATE #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) // GMMA 64x144x32 TN S32+=U8*S8 -struct SM90_64x144x32_S32U8S8_RS_TN +struct MMA_64x144x32_S32U8S8_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -24394,6 +24727,7 @@ struct SM90_64x144x32_S32U8S8_RS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -24434,7 +24768,7 @@ struct SM90_64x144x32_S32U8S8_RS_TN "l"(desc_b), "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x144x32_S32U8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x144x32_S32U8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -24444,7 +24778,7 @@ struct SM90_64x144x32_S32U8S8_RS_TN #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) // GMMA 64x144x32 TN S32+=U8*S8 -struct SM90_64x144x32_S32U8S8_RS_TN_SATURATE +struct MMA_64x144x32_S32U8S8_RS_TN_SATURATE { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -24475,6 +24809,7 @@ struct SM90_64x144x32_S32U8S8_RS_TN_SATURATE GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -24515,7 +24850,7 @@ struct SM90_64x144x32_S32U8S8_RS_TN_SATURATE "l"(desc_b), "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x144x32_S32U8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x144x32_S32U8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -24525,7 +24860,7 @@ struct SM90_64x144x32_S32U8S8_RS_TN_SATURATE #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) // GMMA 64x160x32 TN S32+=U8*S8 -struct SM90_64x160x32_S32U8S8_RS_TN +struct MMA_64x160x32_S32U8S8_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -24558,6 +24893,7 @@ struct SM90_64x160x32_S32U8S8_RS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -24601,7 +24937,7 @@ struct SM90_64x160x32_S32U8S8_RS_TN "l"(desc_b), "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x160x32_S32U8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + 
CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x160x32_S32U8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -24611,7 +24947,7 @@ struct SM90_64x160x32_S32U8S8_RS_TN #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) // GMMA 64x160x32 TN S32+=U8*S8 -struct SM90_64x160x32_S32U8S8_RS_TN_SATURATE +struct MMA_64x160x32_S32U8S8_RS_TN_SATURATE { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -24644,6 +24980,7 @@ struct SM90_64x160x32_S32U8S8_RS_TN_SATURATE GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -24687,7 +25024,7 @@ struct SM90_64x160x32_S32U8S8_RS_TN_SATURATE "l"(desc_b), "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x160x32_S32U8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x160x32_S32U8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -24697,7 +25034,7 @@ struct SM90_64x160x32_S32U8S8_RS_TN_SATURATE #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) // GMMA 64x176x32 TN S32+=U8*S8 -struct SM90_64x176x32_S32U8S8_RS_TN +struct MMA_64x176x32_S32U8S8_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -24732,6 +25069,7 @@ struct SM90_64x176x32_S32U8S8_RS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -24778,7 +25116,7 @@ struct SM90_64x176x32_S32U8S8_RS_TN "l"(desc_b), "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x176x32_S32U8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x176x32_S32U8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -24788,7 +25126,7 @@ struct SM90_64x176x32_S32U8S8_RS_TN #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) // GMMA 64x176x32 TN S32+=U8*S8 -struct SM90_64x176x32_S32U8S8_RS_TN_SATURATE +struct MMA_64x176x32_S32U8S8_RS_TN_SATURATE { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -24823,6 +25161,7 @@ struct SM90_64x176x32_S32U8S8_RS_TN_SATURATE GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -24869,7 +25208,7 @@ struct SM90_64x176x32_S32U8S8_RS_TN_SATURATE "l"(desc_b), "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x176x32_S32U8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x176x32_S32U8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -24878,7 +25217,7 @@ struct SM90_64x176x32_S32U8S8_RS_TN_SATURATE //////////////////////////////////////////////////////////////////////////////////////////////////// // GMMA 64x192x32 TN S32+=U8*S8 -struct SM90_64x192x32_S32U8S8_RS_TN +struct MMA_64x192x32_S32U8S8_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -24915,6 +25254,7 @@ struct SM90_64x192x32_S32U8S8_RS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -24964,7 +25304,7 @@ struct SM90_64x192x32_S32U8S8_RS_TN "l"(desc_b), 
"r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x192x32_S32U8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x192x32_S32U8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -24972,7 +25312,7 @@ struct SM90_64x192x32_S32U8S8_RS_TN //////////////////////////////////////////////////////////////////////////////////////////////////// // GMMA 64x192x32 TN S32+=U8*S8 -struct SM90_64x192x32_S32U8S8_RS_TN_SATURATE +struct MMA_64x192x32_S32U8S8_RS_TN_SATURATE { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -25009,6 +25349,7 @@ struct SM90_64x192x32_S32U8S8_RS_TN_SATURATE GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -25058,7 +25399,7 @@ struct SM90_64x192x32_S32U8S8_RS_TN_SATURATE "l"(desc_b), "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x192x32_S32U8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x192x32_S32U8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -25067,7 +25408,7 @@ struct SM90_64x192x32_S32U8S8_RS_TN_SATURATE #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) // GMMA 64x208x32 TN S32+=U8*S8 -struct SM90_64x208x32_S32U8S8_RS_TN +struct MMA_64x208x32_S32U8S8_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -25106,6 +25447,7 @@ struct SM90_64x208x32_S32U8S8_RS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -25158,7 +25500,7 @@ struct SM90_64x208x32_S32U8S8_RS_TN "l"(desc_b), "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x208x32_S32U8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x208x32_S32U8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -25168,7 +25510,7 @@ struct SM90_64x208x32_S32U8S8_RS_TN #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) // GMMA 64x208x32 TN S32+=U8*S8 -struct SM90_64x208x32_S32U8S8_RS_TN_SATURATE +struct MMA_64x208x32_S32U8S8_RS_TN_SATURATE { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -25207,6 +25549,7 @@ struct SM90_64x208x32_S32U8S8_RS_TN_SATURATE GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -25259,7 +25602,7 @@ struct SM90_64x208x32_S32U8S8_RS_TN_SATURATE "l"(desc_b), "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x208x32_S32U8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x208x32_S32U8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -25269,7 +25612,7 @@ struct SM90_64x208x32_S32U8S8_RS_TN_SATURATE #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) // GMMA 64x224x32 TN S32+=U8*S8 -struct SM90_64x224x32_S32U8S8_RS_TN +struct MMA_64x224x32_S32U8S8_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -25310,6 +25653,7 @@ struct SM90_64x224x32_S32U8S8_RS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + 
cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -25365,7 +25709,7 @@ struct SM90_64x224x32_S32U8S8_RS_TN "l"(desc_b), "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x224x32_S32U8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x224x32_S32U8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -25375,7 +25719,7 @@ struct SM90_64x224x32_S32U8S8_RS_TN #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) // GMMA 64x224x32 TN S32+=U8*S8 -struct SM90_64x224x32_S32U8S8_RS_TN_SATURATE +struct MMA_64x224x32_S32U8S8_RS_TN_SATURATE { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -25416,6 +25760,7 @@ struct SM90_64x224x32_S32U8S8_RS_TN_SATURATE GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -25471,7 +25816,7 @@ struct SM90_64x224x32_S32U8S8_RS_TN_SATURATE "l"(desc_b), "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x224x32_S32U8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x224x32_S32U8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -25481,7 +25826,7 @@ struct SM90_64x224x32_S32U8S8_RS_TN_SATURATE #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) // GMMA 64x240x32 TN S32+=U8*S8 -struct SM90_64x240x32_S32U8S8_RS_TN +struct MMA_64x240x32_S32U8S8_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -25524,6 +25869,7 @@ struct SM90_64x240x32_S32U8S8_RS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -25582,7 +25928,7 @@ struct SM90_64x240x32_S32U8S8_RS_TN "l"(desc_b), "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x240x32_S32U8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x240x32_S32U8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -25592,7 +25938,7 @@ struct SM90_64x240x32_S32U8S8_RS_TN #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) // GMMA 64x240x32 TN S32+=U8*S8 -struct SM90_64x240x32_S32U8S8_RS_TN_SATURATE +struct MMA_64x240x32_S32U8S8_RS_TN_SATURATE { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -25635,6 +25981,7 @@ struct SM90_64x240x32_S32U8S8_RS_TN_SATURATE GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -25693,7 +26040,7 @@ struct SM90_64x240x32_S32U8S8_RS_TN_SATURATE "l"(desc_b), "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x240x32_S32U8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x240x32_S32U8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -25702,7 +26049,7 @@ struct SM90_64x240x32_S32U8S8_RS_TN_SATURATE //////////////////////////////////////////////////////////////////////////////////////////////////// // GMMA 64x256x32 TN S32+=U8*S8 -struct SM90_64x256x32_S32U8S8_RS_TN +struct MMA_64x256x32_S32U8S8_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -25747,6 +26094,7 @@ struct 
SM90_64x256x32_S32U8S8_RS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -25808,7 +26156,7 @@ struct SM90_64x256x32_S32U8S8_RS_TN "l"(desc_b), "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x256x32_S32U8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x256x32_S32U8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -25816,7 +26164,7 @@ struct SM90_64x256x32_S32U8S8_RS_TN //////////////////////////////////////////////////////////////////////////////////////////////////// // GMMA 64x256x32 TN S32+=U8*S8 -struct SM90_64x256x32_S32U8S8_RS_TN_SATURATE +struct MMA_64x256x32_S32U8S8_RS_TN_SATURATE { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -25861,6 +26209,7 @@ struct SM90_64x256x32_S32U8S8_RS_TN_SATURATE GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -25922,7 +26271,7 @@ struct SM90_64x256x32_S32U8S8_RS_TN_SATURATE "l"(desc_b), "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x256x32_S32U8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x256x32_S32U8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -25930,7 +26279,7 @@ struct SM90_64x256x32_S32U8S8_RS_TN_SATURATE //////////////////////////////////////////////////////////////////////////////////////////////////// // GMMA 64x8x32 TN S32+=U8*U8 -struct SM90_64x8x32_S32U8U8_SS_TN +struct MMA_64x8x32_S32U8U8_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -25944,6 +26293,7 @@ struct SM90_64x8x32_S32U8U8_SS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -25959,7 +26309,7 @@ struct SM90_64x8x32_S32U8U8_SS_TN "l"(desc_b), "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x8x32_S32U8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x8x32_S32U8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -25967,7 +26317,7 @@ struct SM90_64x8x32_S32U8U8_SS_TN //////////////////////////////////////////////////////////////////////////////////////////////////// // GMMA 64x8x32 TN S32+=U8*U8 -struct SM90_64x8x32_S32U8U8_SS_TN_SATURATE +struct MMA_64x8x32_S32U8U8_SS_TN_SATURATE { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -25981,6 +26331,7 @@ struct SM90_64x8x32_S32U8U8_SS_TN_SATURATE GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -25996,7 +26347,7 @@ struct SM90_64x8x32_S32U8U8_SS_TN_SATURATE "l"(desc_b), "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x8x32_S32U8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x8x32_S32U8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -26004,7 +26355,7 @@ struct SM90_64x8x32_S32U8U8_SS_TN_SATURATE 
//////////////////////////////////////////////////////////////////////////////////////////////////// // GMMA 64x16x32 TN S32+=U8*U8 -struct SM90_64x16x32_S32U8U8_SS_TN +struct MMA_64x16x32_S32U8U8_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -26019,6 +26370,7 @@ struct SM90_64x16x32_S32U8U8_SS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -26035,7 +26387,7 @@ struct SM90_64x16x32_S32U8U8_SS_TN "l"(desc_b), "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x16x32_S32U8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x16x32_S32U8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -26043,7 +26395,7 @@ struct SM90_64x16x32_S32U8U8_SS_TN //////////////////////////////////////////////////////////////////////////////////////////////////// // GMMA 64x16x32 TN S32+=U8*U8 -struct SM90_64x16x32_S32U8U8_SS_TN_SATURATE +struct MMA_64x16x32_S32U8U8_SS_TN_SATURATE { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -26058,6 +26410,7 @@ struct SM90_64x16x32_S32U8U8_SS_TN_SATURATE GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -26074,7 +26427,7 @@ struct SM90_64x16x32_S32U8U8_SS_TN_SATURATE "l"(desc_b), "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x16x32_S32U8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x16x32_S32U8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -26082,7 +26435,7 @@ struct SM90_64x16x32_S32U8U8_SS_TN_SATURATE //////////////////////////////////////////////////////////////////////////////////////////////////// // GMMA 64x32x32 TN S32+=U8*U8 -struct SM90_64x32x32_S32U8U8_SS_TN +struct MMA_64x32x32_S32U8U8_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -26099,6 +26452,7 @@ struct SM90_64x32x32_S32U8U8_SS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -26118,7 +26472,7 @@ struct SM90_64x32x32_S32U8U8_SS_TN "l"(desc_b), "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x32x32_S32U8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x32x32_S32U8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -26126,7 +26480,7 @@ struct SM90_64x32x32_S32U8U8_SS_TN //////////////////////////////////////////////////////////////////////////////////////////////////// // GMMA 64x32x32 TN S32+=U8*U8 -struct SM90_64x32x32_S32U8U8_SS_TN_SATURATE +struct MMA_64x32x32_S32U8U8_SS_TN_SATURATE { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -26143,6 +26497,7 @@ struct SM90_64x32x32_S32U8U8_SS_TN_SATURATE GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -26162,7 +26517,7 @@ struct SM90_64x32x32_S32U8U8_SS_TN_SATURATE "l"(desc_b), "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to 
use SM90_64x32x32_S32U8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x32x32_S32U8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -26171,7 +26526,7 @@ struct SM90_64x32x32_S32U8U8_SS_TN_SATURATE #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) // GMMA 64x48x32 TN S32+=U8*U8 -struct SM90_64x48x32_S32U8U8_SS_TN +struct MMA_64x48x32_S32U8U8_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -26190,6 +26545,7 @@ struct SM90_64x48x32_S32U8U8_SS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -26212,7 +26568,7 @@ struct SM90_64x48x32_S32U8U8_SS_TN "l"(desc_b), "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x48x32_S32U8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x48x32_S32U8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -26222,7 +26578,7 @@ struct SM90_64x48x32_S32U8U8_SS_TN #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) // GMMA 64x48x32 TN S32+=U8*U8 -struct SM90_64x48x32_S32U8U8_SS_TN_SATURATE +struct MMA_64x48x32_S32U8U8_SS_TN_SATURATE { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -26241,6 +26597,7 @@ struct SM90_64x48x32_S32U8U8_SS_TN_SATURATE GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -26263,7 +26620,7 @@ struct SM90_64x48x32_S32U8U8_SS_TN_SATURATE "l"(desc_b), "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x48x32_S32U8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x48x32_S32U8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -26272,7 +26629,7 @@ struct SM90_64x48x32_S32U8U8_SS_TN_SATURATE //////////////////////////////////////////////////////////////////////////////////////////////////// // GMMA 64x64x32 TN S32+=U8*U8 -struct SM90_64x64x32_S32U8U8_SS_TN +struct MMA_64x64x32_S32U8U8_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -26293,6 +26650,7 @@ struct SM90_64x64x32_S32U8U8_SS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -26318,7 +26676,7 @@ struct SM90_64x64x32_S32U8U8_SS_TN "l"(desc_b), "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x64x32_S32U8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x64x32_S32U8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -26326,7 +26684,7 @@ struct SM90_64x64x32_S32U8U8_SS_TN //////////////////////////////////////////////////////////////////////////////////////////////////// // GMMA 64x64x32 TN S32+=U8*U8 -struct SM90_64x64x32_S32U8U8_SS_TN_SATURATE +struct MMA_64x64x32_S32U8U8_SS_TN_SATURATE { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -26347,6 +26705,7 @@ struct SM90_64x64x32_S32U8U8_SS_TN_SATURATE GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm 
volatile( "{\n" ".reg .pred p;\n" @@ -26372,7 +26731,7 @@ struct SM90_64x64x32_S32U8U8_SS_TN_SATURATE "l"(desc_b), "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x64x32_S32U8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x64x32_S32U8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -26381,7 +26740,7 @@ struct SM90_64x64x32_S32U8U8_SS_TN_SATURATE #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) // GMMA 64x80x32 TN S32+=U8*U8 -struct SM90_64x80x32_S32U8U8_SS_TN +struct MMA_64x80x32_S32U8U8_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -26404,6 +26763,7 @@ struct SM90_64x80x32_S32U8U8_SS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -26432,7 +26792,7 @@ struct SM90_64x80x32_S32U8U8_SS_TN "l"(desc_b), "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x80x32_S32U8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x80x32_S32U8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -26442,7 +26802,7 @@ struct SM90_64x80x32_S32U8U8_SS_TN #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) // GMMA 64x80x32 TN S32+=U8*U8 -struct SM90_64x80x32_S32U8U8_SS_TN_SATURATE +struct MMA_64x80x32_S32U8U8_SS_TN_SATURATE { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -26465,6 +26825,7 @@ struct SM90_64x80x32_S32U8U8_SS_TN_SATURATE GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -26493,7 +26854,7 @@ struct SM90_64x80x32_S32U8U8_SS_TN_SATURATE "l"(desc_b), "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x80x32_S32U8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x80x32_S32U8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -26502,7 +26863,7 @@ struct SM90_64x80x32_S32U8U8_SS_TN_SATURATE //////////////////////////////////////////////////////////////////////////////////////////////////// // GMMA 64x96x32 TN S32+=U8*U8 -struct SM90_64x96x32_S32U8U8_SS_TN +struct MMA_64x96x32_S32U8U8_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -26527,6 +26888,7 @@ struct SM90_64x96x32_S32U8U8_SS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -26558,7 +26920,7 @@ struct SM90_64x96x32_S32U8U8_SS_TN "l"(desc_b), "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x96x32_S32U8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x96x32_S32U8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -26566,7 +26928,7 @@ struct SM90_64x96x32_S32U8U8_SS_TN //////////////////////////////////////////////////////////////////////////////////////////////////// // GMMA 64x96x32 TN S32+=U8*U8 -struct SM90_64x96x32_S32U8U8_SS_TN_SATURATE +struct MMA_64x96x32_S32U8U8_SS_TN_SATURATE { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -26591,6 +26953,7 @@ struct 
SM90_64x96x32_S32U8U8_SS_TN_SATURATE GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -26622,7 +26985,7 @@ struct SM90_64x96x32_S32U8U8_SS_TN_SATURATE "l"(desc_b), "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x96x32_S32U8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x96x32_S32U8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -26631,7 +26994,7 @@ struct SM90_64x96x32_S32U8U8_SS_TN_SATURATE #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) // GMMA 64x112x32 TN S32+=U8*U8 -struct SM90_64x112x32_S32U8U8_SS_TN +struct MMA_64x112x32_S32U8U8_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -26658,6 +27021,7 @@ struct SM90_64x112x32_S32U8U8_SS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -26692,7 +27056,7 @@ struct SM90_64x112x32_S32U8U8_SS_TN "l"(desc_b), "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x112x32_S32U8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x112x32_S32U8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -26702,7 +27066,7 @@ struct SM90_64x112x32_S32U8U8_SS_TN #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) // GMMA 64x112x32 TN S32+=U8*U8 -struct SM90_64x112x32_S32U8U8_SS_TN_SATURATE +struct MMA_64x112x32_S32U8U8_SS_TN_SATURATE { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -26729,6 +27093,7 @@ struct SM90_64x112x32_S32U8U8_SS_TN_SATURATE GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -26763,7 +27128,7 @@ struct SM90_64x112x32_S32U8U8_SS_TN_SATURATE "l"(desc_b), "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x112x32_S32U8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x112x32_S32U8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -26772,7 +27137,7 @@ struct SM90_64x112x32_S32U8U8_SS_TN_SATURATE //////////////////////////////////////////////////////////////////////////////////////////////////// // GMMA 64x128x32 TN S32+=U8*U8 -struct SM90_64x128x32_S32U8U8_SS_TN +struct MMA_64x128x32_S32U8U8_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -26801,6 +27166,7 @@ struct SM90_64x128x32_S32U8U8_SS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -26838,7 +27204,7 @@ struct SM90_64x128x32_S32U8U8_SS_TN "l"(desc_b), "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x128x32_S32U8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x128x32_S32U8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -26846,7 +27212,7 @@ struct SM90_64x128x32_S32U8U8_SS_TN //////////////////////////////////////////////////////////////////////////////////////////////////// // 
GMMA 64x128x32 TN S32+=U8*U8 -struct SM90_64x128x32_S32U8U8_SS_TN_SATURATE +struct MMA_64x128x32_S32U8U8_SS_TN_SATURATE { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -26875,6 +27241,7 @@ struct SM90_64x128x32_S32U8U8_SS_TN_SATURATE GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -26912,7 +27279,7 @@ struct SM90_64x128x32_S32U8U8_SS_TN_SATURATE "l"(desc_b), "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x128x32_S32U8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x128x32_S32U8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -26921,7 +27288,7 @@ struct SM90_64x128x32_S32U8U8_SS_TN_SATURATE #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) // GMMA 64x144x32 TN S32+=U8*U8 -struct SM90_64x144x32_S32U8U8_SS_TN +struct MMA_64x144x32_S32U8U8_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -26952,6 +27319,7 @@ struct SM90_64x144x32_S32U8U8_SS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -26992,7 +27360,7 @@ struct SM90_64x144x32_S32U8U8_SS_TN "l"(desc_b), "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x144x32_S32U8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x144x32_S32U8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -27002,7 +27370,7 @@ struct SM90_64x144x32_S32U8U8_SS_TN #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) // GMMA 64x144x32 TN S32+=U8*U8 -struct SM90_64x144x32_S32U8U8_SS_TN_SATURATE +struct MMA_64x144x32_S32U8U8_SS_TN_SATURATE { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -27033,6 +27401,7 @@ struct SM90_64x144x32_S32U8U8_SS_TN_SATURATE GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -27073,7 +27442,7 @@ struct SM90_64x144x32_S32U8U8_SS_TN_SATURATE "l"(desc_b), "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x144x32_S32U8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x144x32_S32U8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -27083,7 +27452,7 @@ struct SM90_64x144x32_S32U8U8_SS_TN_SATURATE #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) // GMMA 64x160x32 TN S32+=U8*U8 -struct SM90_64x160x32_S32U8U8_SS_TN +struct MMA_64x160x32_S32U8U8_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -27116,6 +27485,7 @@ struct SM90_64x160x32_S32U8U8_SS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -27159,7 +27529,7 @@ struct SM90_64x160x32_S32U8U8_SS_TN "l"(desc_b), "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x160x32_S32U8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x160x32_S32U8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ 
-27169,7 +27539,7 @@ struct SM90_64x160x32_S32U8U8_SS_TN #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) // GMMA 64x160x32 TN S32+=U8*U8 -struct SM90_64x160x32_S32U8U8_SS_TN_SATURATE +struct MMA_64x160x32_S32U8U8_SS_TN_SATURATE { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -27202,6 +27572,7 @@ struct SM90_64x160x32_S32U8U8_SS_TN_SATURATE GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -27245,7 +27616,7 @@ struct SM90_64x160x32_S32U8U8_SS_TN_SATURATE "l"(desc_b), "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x160x32_S32U8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x160x32_S32U8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -27255,7 +27626,7 @@ struct SM90_64x160x32_S32U8U8_SS_TN_SATURATE #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) // GMMA 64x176x32 TN S32+=U8*U8 -struct SM90_64x176x32_S32U8U8_SS_TN +struct MMA_64x176x32_S32U8U8_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -27290,6 +27661,7 @@ struct SM90_64x176x32_S32U8U8_SS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -27336,7 +27708,7 @@ struct SM90_64x176x32_S32U8U8_SS_TN "l"(desc_b), "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x176x32_S32U8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x176x32_S32U8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -27346,7 +27718,7 @@ struct SM90_64x176x32_S32U8U8_SS_TN #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) // GMMA 64x176x32 TN S32+=U8*U8 -struct SM90_64x176x32_S32U8U8_SS_TN_SATURATE +struct MMA_64x176x32_S32U8U8_SS_TN_SATURATE { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -27381,6 +27753,7 @@ struct SM90_64x176x32_S32U8U8_SS_TN_SATURATE GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -27427,7 +27800,7 @@ struct SM90_64x176x32_S32U8U8_SS_TN_SATURATE "l"(desc_b), "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x176x32_S32U8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x176x32_S32U8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -27436,7 +27809,7 @@ struct SM90_64x176x32_S32U8U8_SS_TN_SATURATE //////////////////////////////////////////////////////////////////////////////////////////////////// // GMMA 64x192x32 TN S32+=U8*U8 -struct SM90_64x192x32_S32U8U8_SS_TN +struct MMA_64x192x32_S32U8U8_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -27473,6 +27846,7 @@ struct SM90_64x192x32_S32U8U8_SS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -27522,7 +27896,7 @@ struct SM90_64x192x32_S32U8U8_SS_TN "l"(desc_b), "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x192x32_S32U8U8_SS_TN 
without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x192x32_S32U8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -27530,7 +27904,7 @@ struct SM90_64x192x32_S32U8U8_SS_TN //////////////////////////////////////////////////////////////////////////////////////////////////// // GMMA 64x192x32 TN S32+=U8*U8 -struct SM90_64x192x32_S32U8U8_SS_TN_SATURATE +struct MMA_64x192x32_S32U8U8_SS_TN_SATURATE { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -27567,6 +27941,7 @@ struct SM90_64x192x32_S32U8U8_SS_TN_SATURATE GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -27616,7 +27991,7 @@ struct SM90_64x192x32_S32U8U8_SS_TN_SATURATE "l"(desc_b), "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x192x32_S32U8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x192x32_S32U8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -27625,7 +28000,7 @@ struct SM90_64x192x32_S32U8U8_SS_TN_SATURATE #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) // GMMA 64x208x32 TN S32+=U8*U8 -struct SM90_64x208x32_S32U8U8_SS_TN +struct MMA_64x208x32_S32U8U8_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -27664,6 +28039,7 @@ struct SM90_64x208x32_S32U8U8_SS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -27716,7 +28092,7 @@ struct SM90_64x208x32_S32U8U8_SS_TN "l"(desc_b), "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x208x32_S32U8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x208x32_S32U8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -27726,7 +28102,7 @@ struct SM90_64x208x32_S32U8U8_SS_TN #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) // GMMA 64x208x32 TN S32+=U8*U8 -struct SM90_64x208x32_S32U8U8_SS_TN_SATURATE +struct MMA_64x208x32_S32U8U8_SS_TN_SATURATE { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -27765,6 +28141,7 @@ struct SM90_64x208x32_S32U8U8_SS_TN_SATURATE GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -27817,7 +28194,7 @@ struct SM90_64x208x32_S32U8U8_SS_TN_SATURATE "l"(desc_b), "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x208x32_S32U8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x208x32_S32U8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -27827,7 +28204,7 @@ struct SM90_64x208x32_S32U8U8_SS_TN_SATURATE #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) // GMMA 64x224x32 TN S32+=U8*U8 -struct SM90_64x224x32_S32U8U8_SS_TN +struct MMA_64x224x32_S32U8U8_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -27868,6 +28245,7 @@ struct SM90_64x224x32_S32U8U8_SS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -27923,7 
+28301,7 @@ struct SM90_64x224x32_S32U8U8_SS_TN "l"(desc_b), "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x224x32_S32U8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x224x32_S32U8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -27933,7 +28311,7 @@ struct SM90_64x224x32_S32U8U8_SS_TN #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) // GMMA 64x224x32 TN S32+=U8*U8 -struct SM90_64x224x32_S32U8U8_SS_TN_SATURATE +struct MMA_64x224x32_S32U8U8_SS_TN_SATURATE { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -27974,6 +28352,7 @@ struct SM90_64x224x32_S32U8U8_SS_TN_SATURATE GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -28029,7 +28408,7 @@ struct SM90_64x224x32_S32U8U8_SS_TN_SATURATE "l"(desc_b), "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x224x32_S32U8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x224x32_S32U8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -28039,7 +28418,7 @@ struct SM90_64x224x32_S32U8U8_SS_TN_SATURATE #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) // GMMA 64x240x32 TN S32+=U8*U8 -struct SM90_64x240x32_S32U8U8_SS_TN +struct MMA_64x240x32_S32U8U8_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -28082,6 +28461,7 @@ struct SM90_64x240x32_S32U8U8_SS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -28140,7 +28520,7 @@ struct SM90_64x240x32_S32U8U8_SS_TN "l"(desc_b), "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x240x32_S32U8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x240x32_S32U8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -28150,7 +28530,7 @@ struct SM90_64x240x32_S32U8U8_SS_TN #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) // GMMA 64x240x32 TN S32+=U8*U8 -struct SM90_64x240x32_S32U8U8_SS_TN_SATURATE +struct MMA_64x240x32_S32U8U8_SS_TN_SATURATE { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -28193,6 +28573,7 @@ struct SM90_64x240x32_S32U8U8_SS_TN_SATURATE GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -28251,7 +28632,7 @@ struct SM90_64x240x32_S32U8U8_SS_TN_SATURATE "l"(desc_b), "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x240x32_S32U8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x240x32_S32U8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -28260,7 +28641,7 @@ struct SM90_64x240x32_S32U8U8_SS_TN_SATURATE //////////////////////////////////////////////////////////////////////////////////////////////////// // GMMA 64x256x32 TN S32+=U8*U8 -struct SM90_64x256x32_S32U8U8_SS_TN +struct MMA_64x256x32_S32U8U8_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -28305,6 +28686,7 @@ struct SM90_64x256x32_S32U8U8_SS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { 
#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -28366,7 +28748,7 @@ struct SM90_64x256x32_S32U8U8_SS_TN "l"(desc_b), "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x256x32_S32U8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x256x32_S32U8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -28374,7 +28756,7 @@ struct SM90_64x256x32_S32U8U8_SS_TN //////////////////////////////////////////////////////////////////////////////////////////////////// // GMMA 64x256x32 TN S32+=U8*U8 -struct SM90_64x256x32_S32U8U8_SS_TN_SATURATE +struct MMA_64x256x32_S32U8U8_SS_TN_SATURATE { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -28419,6 +28801,7 @@ struct SM90_64x256x32_S32U8U8_SS_TN_SATURATE GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -28480,7 +28863,7 @@ struct SM90_64x256x32_S32U8U8_SS_TN_SATURATE "l"(desc_b), "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x256x32_S32U8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x256x32_S32U8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -28488,7 +28871,7 @@ struct SM90_64x256x32_S32U8U8_SS_TN_SATURATE //////////////////////////////////////////////////////////////////////////////////////////////////// // GMMA 64x8x32 TN S32+=U8*U8 -struct SM90_64x8x32_S32U8U8_RS_TN +struct MMA_64x8x32_S32U8U8_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -28502,6 +28885,7 @@ struct SM90_64x8x32_S32U8U8_RS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -28517,7 +28901,7 @@ struct SM90_64x8x32_S32U8U8_RS_TN "l"(desc_b), "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x8x32_S32U8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x8x32_S32U8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -28525,7 +28909,7 @@ struct SM90_64x8x32_S32U8U8_RS_TN //////////////////////////////////////////////////////////////////////////////////////////////////// // GMMA 64x8x32 TN S32+=U8*U8 -struct SM90_64x8x32_S32U8U8_RS_TN_SATURATE +struct MMA_64x8x32_S32U8U8_RS_TN_SATURATE { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -28539,6 +28923,7 @@ struct SM90_64x8x32_S32U8U8_RS_TN_SATURATE GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -28554,7 +28939,7 @@ struct SM90_64x8x32_S32U8U8_RS_TN_SATURATE "l"(desc_b), "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x8x32_S32U8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x8x32_S32U8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -28562,7 +28947,7 @@ struct SM90_64x8x32_S32U8U8_RS_TN_SATURATE //////////////////////////////////////////////////////////////////////////////////////////////////// // GMMA 64x16x32 
TN S32+=U8*U8 -struct SM90_64x16x32_S32U8U8_RS_TN +struct MMA_64x16x32_S32U8U8_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -28577,6 +28962,7 @@ struct SM90_64x16x32_S32U8U8_RS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -28593,7 +28979,7 @@ struct SM90_64x16x32_S32U8U8_RS_TN "l"(desc_b), "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x16x32_S32U8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x16x32_S32U8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -28601,7 +28987,7 @@ struct SM90_64x16x32_S32U8U8_RS_TN //////////////////////////////////////////////////////////////////////////////////////////////////// // GMMA 64x16x32 TN S32+=U8*U8 -struct SM90_64x16x32_S32U8U8_RS_TN_SATURATE +struct MMA_64x16x32_S32U8U8_RS_TN_SATURATE { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -28616,6 +29002,7 @@ struct SM90_64x16x32_S32U8U8_RS_TN_SATURATE GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -28632,7 +29019,7 @@ struct SM90_64x16x32_S32U8U8_RS_TN_SATURATE "l"(desc_b), "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x16x32_S32U8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x16x32_S32U8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -28640,7 +29027,7 @@ struct SM90_64x16x32_S32U8U8_RS_TN_SATURATE //////////////////////////////////////////////////////////////////////////////////////////////////// // GMMA 64x32x32 TN S32+=U8*U8 -struct SM90_64x32x32_S32U8U8_RS_TN +struct MMA_64x32x32_S32U8U8_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -28657,6 +29044,7 @@ struct SM90_64x32x32_S32U8U8_RS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -28676,7 +29064,7 @@ struct SM90_64x32x32_S32U8U8_RS_TN "l"(desc_b), "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x32x32_S32U8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x32x32_S32U8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -28684,7 +29072,7 @@ struct SM90_64x32x32_S32U8U8_RS_TN //////////////////////////////////////////////////////////////////////////////////////////////////// // GMMA 64x32x32 TN S32+=U8*U8 -struct SM90_64x32x32_S32U8U8_RS_TN_SATURATE +struct MMA_64x32x32_S32U8U8_RS_TN_SATURATE { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -28701,6 +29089,7 @@ struct SM90_64x32x32_S32U8U8_RS_TN_SATURATE GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -28720,7 +29109,7 @@ struct SM90_64x32x32_S32U8U8_RS_TN_SATURATE "l"(desc_b), "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x32x32_S32U8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use 
MMA_64x32x32_S32U8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -28729,7 +29118,7 @@ struct SM90_64x32x32_S32U8U8_RS_TN_SATURATE #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) // GMMA 64x48x32 TN S32+=U8*U8 -struct SM90_64x48x32_S32U8U8_RS_TN +struct MMA_64x48x32_S32U8U8_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -28748,6 +29137,7 @@ struct SM90_64x48x32_S32U8U8_RS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -28770,7 +29160,7 @@ struct SM90_64x48x32_S32U8U8_RS_TN "l"(desc_b), "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x48x32_S32U8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x48x32_S32U8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -28780,7 +29170,7 @@ struct SM90_64x48x32_S32U8U8_RS_TN #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) // GMMA 64x48x32 TN S32+=U8*U8 -struct SM90_64x48x32_S32U8U8_RS_TN_SATURATE +struct MMA_64x48x32_S32U8U8_RS_TN_SATURATE { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -28799,6 +29189,7 @@ struct SM90_64x48x32_S32U8U8_RS_TN_SATURATE GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -28821,7 +29212,7 @@ struct SM90_64x48x32_S32U8U8_RS_TN_SATURATE "l"(desc_b), "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x48x32_S32U8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x48x32_S32U8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -28830,7 +29221,7 @@ struct SM90_64x48x32_S32U8U8_RS_TN_SATURATE //////////////////////////////////////////////////////////////////////////////////////////////////// // GMMA 64x64x32 TN S32+=U8*U8 -struct SM90_64x64x32_S32U8U8_RS_TN +struct MMA_64x64x32_S32U8U8_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -28851,6 +29242,7 @@ struct SM90_64x64x32_S32U8U8_RS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -28876,7 +29268,7 @@ struct SM90_64x64x32_S32U8U8_RS_TN "l"(desc_b), "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x64x32_S32U8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x64x32_S32U8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -28884,7 +29276,7 @@ struct SM90_64x64x32_S32U8U8_RS_TN //////////////////////////////////////////////////////////////////////////////////////////////////// // GMMA 64x64x32 TN S32+=U8*U8 -struct SM90_64x64x32_S32U8U8_RS_TN_SATURATE +struct MMA_64x64x32_S32U8U8_RS_TN_SATURATE { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -28905,6 +29297,7 @@ struct SM90_64x64x32_S32U8U8_RS_TN_SATURATE GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -28930,7 +29323,7 @@ struct SM90_64x64x32_S32U8U8_RS_TN_SATURATE "l"(desc_b), "r"(int32_t(scale_D))); #else - 
CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x64x32_S32U8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x64x32_S32U8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -28939,7 +29332,7 @@ struct SM90_64x64x32_S32U8U8_RS_TN_SATURATE #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) // GMMA 64x80x32 TN S32+=U8*U8 -struct SM90_64x80x32_S32U8U8_RS_TN +struct MMA_64x80x32_S32U8U8_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -28962,6 +29355,7 @@ struct SM90_64x80x32_S32U8U8_RS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -28990,7 +29384,7 @@ struct SM90_64x80x32_S32U8U8_RS_TN "l"(desc_b), "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x80x32_S32U8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x80x32_S32U8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -29000,7 +29394,7 @@ struct SM90_64x80x32_S32U8U8_RS_TN #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) // GMMA 64x80x32 TN S32+=U8*U8 -struct SM90_64x80x32_S32U8U8_RS_TN_SATURATE +struct MMA_64x80x32_S32U8U8_RS_TN_SATURATE { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -29023,6 +29417,7 @@ struct SM90_64x80x32_S32U8U8_RS_TN_SATURATE GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -29051,7 +29446,7 @@ struct SM90_64x80x32_S32U8U8_RS_TN_SATURATE "l"(desc_b), "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x80x32_S32U8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x80x32_S32U8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -29060,7 +29455,7 @@ struct SM90_64x80x32_S32U8U8_RS_TN_SATURATE //////////////////////////////////////////////////////////////////////////////////////////////////// // GMMA 64x96x32 TN S32+=U8*U8 -struct SM90_64x96x32_S32U8U8_RS_TN +struct MMA_64x96x32_S32U8U8_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -29085,6 +29480,7 @@ struct SM90_64x96x32_S32U8U8_RS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -29116,7 +29512,7 @@ struct SM90_64x96x32_S32U8U8_RS_TN "l"(desc_b), "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x96x32_S32U8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x96x32_S32U8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -29124,7 +29520,7 @@ struct SM90_64x96x32_S32U8U8_RS_TN //////////////////////////////////////////////////////////////////////////////////////////////////// // GMMA 64x96x32 TN S32+=U8*U8 -struct SM90_64x96x32_S32U8U8_RS_TN_SATURATE +struct MMA_64x96x32_S32U8U8_RS_TN_SATURATE { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -29149,6 +29545,7 @@ struct SM90_64x96x32_S32U8U8_RS_TN_SATURATE GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm 
volatile( "{\n" ".reg .pred p;\n" @@ -29180,7 +29577,7 @@ struct SM90_64x96x32_S32U8U8_RS_TN_SATURATE "l"(desc_b), "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x96x32_S32U8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x96x32_S32U8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -29189,7 +29586,7 @@ struct SM90_64x96x32_S32U8U8_RS_TN_SATURATE #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) // GMMA 64x112x32 TN S32+=U8*U8 -struct SM90_64x112x32_S32U8U8_RS_TN +struct MMA_64x112x32_S32U8U8_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -29216,6 +29613,7 @@ struct SM90_64x112x32_S32U8U8_RS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -29250,7 +29648,7 @@ struct SM90_64x112x32_S32U8U8_RS_TN "l"(desc_b), "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x112x32_S32U8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x112x32_S32U8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -29260,7 +29658,7 @@ struct SM90_64x112x32_S32U8U8_RS_TN #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) // GMMA 64x112x32 TN S32+=U8*U8 -struct SM90_64x112x32_S32U8U8_RS_TN_SATURATE +struct MMA_64x112x32_S32U8U8_RS_TN_SATURATE { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -29287,6 +29685,7 @@ struct SM90_64x112x32_S32U8U8_RS_TN_SATURATE GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -29321,7 +29720,7 @@ struct SM90_64x112x32_S32U8U8_RS_TN_SATURATE "l"(desc_b), "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x112x32_S32U8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x112x32_S32U8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -29330,7 +29729,7 @@ struct SM90_64x112x32_S32U8U8_RS_TN_SATURATE //////////////////////////////////////////////////////////////////////////////////////////////////// // GMMA 64x128x32 TN S32+=U8*U8 -struct SM90_64x128x32_S32U8U8_RS_TN +struct MMA_64x128x32_S32U8U8_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -29359,6 +29758,7 @@ struct SM90_64x128x32_S32U8U8_RS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -29396,7 +29796,7 @@ struct SM90_64x128x32_S32U8U8_RS_TN "l"(desc_b), "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x128x32_S32U8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x128x32_S32U8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -29404,7 +29804,7 @@ struct SM90_64x128x32_S32U8U8_RS_TN //////////////////////////////////////////////////////////////////////////////////////////////////// // GMMA 64x128x32 TN S32+=U8*U8 -struct SM90_64x128x32_S32U8U8_RS_TN_SATURATE +struct MMA_64x128x32_S32U8U8_RS_TN_SATURATE { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -29433,6 +29833,7 @@ struct 
SM90_64x128x32_S32U8U8_RS_TN_SATURATE GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -29470,7 +29871,7 @@ struct SM90_64x128x32_S32U8U8_RS_TN_SATURATE "l"(desc_b), "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x128x32_S32U8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x128x32_S32U8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -29479,7 +29880,7 @@ struct SM90_64x128x32_S32U8U8_RS_TN_SATURATE #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) // GMMA 64x144x32 TN S32+=U8*U8 -struct SM90_64x144x32_S32U8U8_RS_TN +struct MMA_64x144x32_S32U8U8_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -29510,6 +29911,7 @@ struct SM90_64x144x32_S32U8U8_RS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -29550,7 +29952,7 @@ struct SM90_64x144x32_S32U8U8_RS_TN "l"(desc_b), "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x144x32_S32U8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x144x32_S32U8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -29560,7 +29962,7 @@ struct SM90_64x144x32_S32U8U8_RS_TN #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) // GMMA 64x144x32 TN S32+=U8*U8 -struct SM90_64x144x32_S32U8U8_RS_TN_SATURATE +struct MMA_64x144x32_S32U8U8_RS_TN_SATURATE { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -29591,6 +29993,7 @@ struct SM90_64x144x32_S32U8U8_RS_TN_SATURATE GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -29631,7 +30034,7 @@ struct SM90_64x144x32_S32U8U8_RS_TN_SATURATE "l"(desc_b), "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x144x32_S32U8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x144x32_S32U8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -29641,7 +30044,7 @@ struct SM90_64x144x32_S32U8U8_RS_TN_SATURATE #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) // GMMA 64x160x32 TN S32+=U8*U8 -struct SM90_64x160x32_S32U8U8_RS_TN +struct MMA_64x160x32_S32U8U8_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -29674,6 +30077,7 @@ struct SM90_64x160x32_S32U8U8_RS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -29717,7 +30121,7 @@ struct SM90_64x160x32_S32U8U8_RS_TN "l"(desc_b), "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x160x32_S32U8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x160x32_S32U8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -29727,7 +30131,7 @@ struct SM90_64x160x32_S32U8U8_RS_TN #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) // GMMA 64x160x32 TN S32+=U8*U8 -struct SM90_64x160x32_S32U8U8_RS_TN_SATURATE +struct MMA_64x160x32_S32U8U8_RS_TN_SATURATE { using 
DRegisters = void; using ARegisters = uint32_t[4]; @@ -29760,6 +30164,7 @@ struct SM90_64x160x32_S32U8U8_RS_TN_SATURATE GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -29803,7 +30208,7 @@ struct SM90_64x160x32_S32U8U8_RS_TN_SATURATE "l"(desc_b), "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x160x32_S32U8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x160x32_S32U8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -29813,7 +30218,7 @@ struct SM90_64x160x32_S32U8U8_RS_TN_SATURATE #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) // GMMA 64x176x32 TN S32+=U8*U8 -struct SM90_64x176x32_S32U8U8_RS_TN +struct MMA_64x176x32_S32U8U8_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -29848,6 +30253,7 @@ struct SM90_64x176x32_S32U8U8_RS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -29894,7 +30300,7 @@ struct SM90_64x176x32_S32U8U8_RS_TN "l"(desc_b), "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x176x32_S32U8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x176x32_S32U8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -29904,7 +30310,7 @@ struct SM90_64x176x32_S32U8U8_RS_TN #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) // GMMA 64x176x32 TN S32+=U8*U8 -struct SM90_64x176x32_S32U8U8_RS_TN_SATURATE +struct MMA_64x176x32_S32U8U8_RS_TN_SATURATE { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -29939,6 +30345,7 @@ struct SM90_64x176x32_S32U8U8_RS_TN_SATURATE GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -29985,7 +30392,7 @@ struct SM90_64x176x32_S32U8U8_RS_TN_SATURATE "l"(desc_b), "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x176x32_S32U8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x176x32_S32U8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -29994,7 +30401,7 @@ struct SM90_64x176x32_S32U8U8_RS_TN_SATURATE //////////////////////////////////////////////////////////////////////////////////////////////////// // GMMA 64x192x32 TN S32+=U8*U8 -struct SM90_64x192x32_S32U8U8_RS_TN +struct MMA_64x192x32_S32U8U8_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -30031,6 +30438,7 @@ struct SM90_64x192x32_S32U8U8_RS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -30080,7 +30488,7 @@ struct SM90_64x192x32_S32U8U8_RS_TN "l"(desc_b), "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x192x32_S32U8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x192x32_S32U8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -30088,7 +30496,7 @@ struct SM90_64x192x32_S32U8U8_RS_TN 
//////////////////////////////////////////////////////////////////////////////////////////////////// // GMMA 64x192x32 TN S32+=U8*U8 -struct SM90_64x192x32_S32U8U8_RS_TN_SATURATE +struct MMA_64x192x32_S32U8U8_RS_TN_SATURATE { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -30125,6 +30533,7 @@ struct SM90_64x192x32_S32U8U8_RS_TN_SATURATE GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -30174,7 +30583,7 @@ struct SM90_64x192x32_S32U8U8_RS_TN_SATURATE "l"(desc_b), "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x192x32_S32U8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x192x32_S32U8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -30183,7 +30592,7 @@ struct SM90_64x192x32_S32U8U8_RS_TN_SATURATE #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) // GMMA 64x208x32 TN S32+=U8*U8 -struct SM90_64x208x32_S32U8U8_RS_TN +struct MMA_64x208x32_S32U8U8_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -30222,6 +30631,7 @@ struct SM90_64x208x32_S32U8U8_RS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -30274,7 +30684,7 @@ struct SM90_64x208x32_S32U8U8_RS_TN "l"(desc_b), "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x208x32_S32U8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x208x32_S32U8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -30284,7 +30694,7 @@ struct SM90_64x208x32_S32U8U8_RS_TN #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) // GMMA 64x208x32 TN S32+=U8*U8 -struct SM90_64x208x32_S32U8U8_RS_TN_SATURATE +struct MMA_64x208x32_S32U8U8_RS_TN_SATURATE { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -30323,6 +30733,7 @@ struct SM90_64x208x32_S32U8U8_RS_TN_SATURATE GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -30375,7 +30786,7 @@ struct SM90_64x208x32_S32U8U8_RS_TN_SATURATE "l"(desc_b), "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x208x32_S32U8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x208x32_S32U8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -30385,7 +30796,7 @@ struct SM90_64x208x32_S32U8U8_RS_TN_SATURATE #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) // GMMA 64x224x32 TN S32+=U8*U8 -struct SM90_64x224x32_S32U8U8_RS_TN +struct MMA_64x224x32_S32U8U8_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -30426,6 +30837,7 @@ struct SM90_64x224x32_S32U8U8_RS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -30481,7 +30893,7 @@ struct SM90_64x224x32_S32U8U8_RS_TN "l"(desc_b), "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x224x32_S32U8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use 
MMA_64x224x32_S32U8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -30491,7 +30903,7 @@ struct SM90_64x224x32_S32U8U8_RS_TN #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) // GMMA 64x224x32 TN S32+=U8*U8 -struct SM90_64x224x32_S32U8U8_RS_TN_SATURATE +struct MMA_64x224x32_S32U8U8_RS_TN_SATURATE { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -30532,6 +30944,7 @@ struct SM90_64x224x32_S32U8U8_RS_TN_SATURATE GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -30587,7 +31000,7 @@ struct SM90_64x224x32_S32U8U8_RS_TN_SATURATE "l"(desc_b), "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x224x32_S32U8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x224x32_S32U8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -30597,7 +31010,7 @@ struct SM90_64x224x32_S32U8U8_RS_TN_SATURATE #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) // GMMA 64x240x32 TN S32+=U8*U8 -struct SM90_64x240x32_S32U8U8_RS_TN +struct MMA_64x240x32_S32U8U8_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -30640,6 +31053,7 @@ struct SM90_64x240x32_S32U8U8_RS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -30698,7 +31112,7 @@ struct SM90_64x240x32_S32U8U8_RS_TN "l"(desc_b), "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x240x32_S32U8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x240x32_S32U8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -30708,7 +31122,7 @@ struct SM90_64x240x32_S32U8U8_RS_TN #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) // GMMA 64x240x32 TN S32+=U8*U8 -struct SM90_64x240x32_S32U8U8_RS_TN_SATURATE +struct MMA_64x240x32_S32U8U8_RS_TN_SATURATE { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -30751,6 +31165,7 @@ struct SM90_64x240x32_S32U8U8_RS_TN_SATURATE GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -30809,7 +31224,7 @@ struct SM90_64x240x32_S32U8U8_RS_TN_SATURATE "l"(desc_b), "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x240x32_S32U8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x240x32_S32U8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -30818,7 +31233,7 @@ struct SM90_64x240x32_S32U8U8_RS_TN_SATURATE //////////////////////////////////////////////////////////////////////////////////////////////////// // GMMA 64x256x32 TN S32+=U8*U8 -struct SM90_64x256x32_S32U8U8_RS_TN +struct MMA_64x256x32_S32U8U8_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -30863,6 +31278,7 @@ struct SM90_64x256x32_S32U8U8_RS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -30924,7 +31340,7 @@ struct SM90_64x256x32_S32U8U8_RS_TN "l"(desc_b), "r"(int32_t(scale_D))); #else - 
CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x256x32_S32U8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x256x32_S32U8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -30932,7 +31348,7 @@ struct SM90_64x256x32_S32U8U8_RS_TN //////////////////////////////////////////////////////////////////////////////////////////////////// // GMMA 64x256x32 TN S32+=U8*U8 -struct SM90_64x256x32_S32U8U8_RS_TN_SATURATE +struct MMA_64x256x32_S32U8U8_RS_TN_SATURATE { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -30977,6 +31393,7 @@ struct SM90_64x256x32_S32U8U8_RS_TN_SATURATE GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -31038,7 +31455,7 @@ struct SM90_64x256x32_S32U8U8_RS_TN_SATURATE "l"(desc_b), "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x256x32_S32U8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x256x32_S32U8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -31050,7 +31467,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x8x32_F16E4M3E4M3_SS_TN +struct MMA_64x8x32_F16E4M3E4M3_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -31064,6 +31481,7 @@ struct SM90_64x8x32_F16E4M3E4M3_SS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -31079,7 +31497,7 @@ struct SM90_64x8x32_F16E4M3E4M3_SS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x8x32_F16E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x8x32_F16E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -31091,7 +31509,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x8x32_F16E4M3E4M3_RS_TN +struct MMA_64x8x32_F16E4M3E4M3_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -31105,6 +31523,7 @@ struct SM90_64x8x32_F16E4M3E4M3_RS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -31120,7 +31539,7 @@ struct SM90_64x8x32_F16E4M3E4M3_RS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x8x32_F16E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x8x32_F16E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -31132,7 +31551,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x8x32_F32E4M3E4M3_SS_TN +struct MMA_64x8x32_F32E4M3E4M3_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -31146,6 +31565,7 @@ struct SM90_64x8x32_F32E4M3E4M3_SS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( 
"{\n" ".reg .pred p;\n" @@ -31161,7 +31581,7 @@ struct SM90_64x8x32_F32E4M3E4M3_SS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x8x32_F32E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x8x32_F32E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -31173,7 +31593,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x8x32_F32E4M3E4M3_RS_TN +struct MMA_64x8x32_F32E4M3E4M3_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -31187,6 +31607,7 @@ struct SM90_64x8x32_F32E4M3E4M3_RS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -31202,7 +31623,7 @@ struct SM90_64x8x32_F32E4M3E4M3_RS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x8x32_F32E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x8x32_F32E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -31214,7 +31635,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x16x32_F16E4M3E4M3_SS_TN +struct MMA_64x16x32_F16E4M3E4M3_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -31228,6 +31649,7 @@ struct SM90_64x16x32_F16E4M3E4M3_SS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -31243,7 +31665,7 @@ struct SM90_64x16x32_F16E4M3E4M3_SS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x16x32_F16E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x16x32_F16E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -31255,7 +31677,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x16x32_F16E4M3E4M3_RS_TN +struct MMA_64x16x32_F16E4M3E4M3_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -31269,6 +31691,7 @@ struct SM90_64x16x32_F16E4M3E4M3_RS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -31284,7 +31707,7 @@ struct SM90_64x16x32_F16E4M3E4M3_RS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x16x32_F16E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x16x32_F16E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -31296,7 +31719,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x16x32_F32E4M3E4M3_SS_TN +struct MMA_64x16x32_F32E4M3E4M3_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -31311,6 +31734,7 @@ struct SM90_64x16x32_F32E4M3E4M3_SS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if 
defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -31327,7 +31751,7 @@ struct SM90_64x16x32_F32E4M3E4M3_SS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x16x32_F32E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x16x32_F32E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -31339,7 +31763,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x16x32_F32E4M3E4M3_RS_TN +struct MMA_64x16x32_F32E4M3E4M3_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -31354,6 +31778,7 @@ struct SM90_64x16x32_F32E4M3E4M3_RS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -31370,7 +31795,7 @@ struct SM90_64x16x32_F32E4M3E4M3_RS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x16x32_F32E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x16x32_F32E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -31382,7 +31807,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x32x32_F16E4M3E4M3_SS_TN +struct MMA_64x32x32_F16E4M3E4M3_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -31397,6 +31822,7 @@ struct SM90_64x32x32_F16E4M3E4M3_SS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -31413,7 +31839,7 @@ struct SM90_64x32x32_F16E4M3E4M3_SS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x32x32_F16E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x32x32_F16E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -31425,7 +31851,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x32x32_F16E4M3E4M3_RS_TN +struct MMA_64x32x32_F16E4M3E4M3_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -31440,6 +31866,7 @@ struct SM90_64x32x32_F16E4M3E4M3_RS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -31456,7 +31883,7 @@ struct SM90_64x32x32_F16E4M3E4M3_RS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x32x32_F16E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x32x32_F16E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -31468,7 +31895,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x32x32_F32E4M3E4M3_SS_TN +struct MMA_64x32x32_F32E4M3E4M3_SS_TN { using DRegisters = void; using ARegisters = 
uint64_t[1]; @@ -31485,6 +31912,7 @@ struct SM90_64x32x32_F32E4M3E4M3_SS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -31504,7 +31932,7 @@ struct SM90_64x32x32_F32E4M3E4M3_SS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x32x32_F32E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x32x32_F32E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -31516,7 +31944,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x32x32_F32E4M3E4M3_RS_TN +struct MMA_64x32x32_F32E4M3E4M3_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -31533,6 +31961,7 @@ struct SM90_64x32x32_F32E4M3E4M3_RS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -31552,7 +31981,7 @@ struct SM90_64x32x32_F32E4M3E4M3_RS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x32x32_F32E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x32x32_F32E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -31565,7 +31994,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x48x32_F16E4M3E4M3_SS_TN +struct MMA_64x48x32_F16E4M3E4M3_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -31581,6 +32010,7 @@ struct SM90_64x48x32_F16E4M3E4M3_SS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -31599,7 +32029,7 @@ struct SM90_64x48x32_F16E4M3E4M3_SS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x48x32_F16E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x48x32_F16E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -31613,7 +32043,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x48x32_F16E4M3E4M3_RS_TN +struct MMA_64x48x32_F16E4M3E4M3_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -31629,6 +32059,7 @@ struct SM90_64x48x32_F16E4M3E4M3_RS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -31647,7 +32078,7 @@ struct SM90_64x48x32_F16E4M3E4M3_RS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x48x32_F16E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x48x32_F16E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -31661,7 +32092,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = 
GMMA::ScaleIn::One > -struct SM90_64x48x32_F32E4M3E4M3_SS_TN +struct MMA_64x48x32_F32E4M3E4M3_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -31680,6 +32111,7 @@ struct SM90_64x48x32_F32E4M3E4M3_SS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -31702,7 +32134,7 @@ struct SM90_64x48x32_F32E4M3E4M3_SS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x48x32_F32E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x48x32_F32E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -31716,7 +32148,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x48x32_F32E4M3E4M3_RS_TN +struct MMA_64x48x32_F32E4M3E4M3_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -31735,6 +32167,7 @@ struct SM90_64x48x32_F32E4M3E4M3_RS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -31757,7 +32190,7 @@ struct SM90_64x48x32_F32E4M3E4M3_RS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x48x32_F32E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x48x32_F32E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -31770,7 +32203,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x64x32_F16E4M3E4M3_SS_TN +struct MMA_64x64x32_F16E4M3E4M3_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -31787,6 +32220,7 @@ struct SM90_64x64x32_F16E4M3E4M3_SS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -31806,7 +32240,7 @@ struct SM90_64x64x32_F16E4M3E4M3_SS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x64x32_F16E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x64x32_F16E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -31818,7 +32252,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x64x32_F16E4M3E4M3_RS_TN +struct MMA_64x64x32_F16E4M3E4M3_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -31835,6 +32269,7 @@ struct SM90_64x64x32_F16E4M3E4M3_RS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -31854,7 +32289,7 @@ struct SM90_64x64x32_F16E4M3E4M3_RS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x64x32_F16E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x64x32_F16E4M3E4M3_RS_TN without 
CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -31866,7 +32301,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x64x32_F32E4M3E4M3_SS_TN +struct MMA_64x64x32_F32E4M3E4M3_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -31887,6 +32322,7 @@ struct SM90_64x64x32_F32E4M3E4M3_SS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -31912,7 +32348,7 @@ struct SM90_64x64x32_F32E4M3E4M3_SS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x64x32_F32E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x64x32_F32E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -31924,7 +32360,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x64x32_F32E4M3E4M3_RS_TN +struct MMA_64x64x32_F32E4M3E4M3_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -31945,6 +32381,7 @@ struct SM90_64x64x32_F32E4M3E4M3_RS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -31970,7 +32407,7 @@ struct SM90_64x64x32_F32E4M3E4M3_RS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x64x32_F32E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x64x32_F32E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -31983,7 +32420,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x80x32_F16E4M3E4M3_SS_TN +struct MMA_64x80x32_F16E4M3E4M3_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -32001,6 +32438,7 @@ struct SM90_64x80x32_F16E4M3E4M3_SS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -32022,7 +32460,7 @@ struct SM90_64x80x32_F16E4M3E4M3_SS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x80x32_F16E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x80x32_F16E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -32036,7 +32474,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x80x32_F16E4M3E4M3_RS_TN +struct MMA_64x80x32_F16E4M3E4M3_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -32054,6 +32492,7 @@ struct SM90_64x80x32_F16E4M3E4M3_RS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -32075,7 +32514,7 @@ struct SM90_64x80x32_F16E4M3E4M3_RS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use 
SM90_64x80x32_F16E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x80x32_F16E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -32089,7 +32528,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x80x32_F32E4M3E4M3_SS_TN +struct MMA_64x80x32_F32E4M3E4M3_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -32112,6 +32551,7 @@ struct SM90_64x80x32_F32E4M3E4M3_SS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -32140,7 +32580,7 @@ struct SM90_64x80x32_F32E4M3E4M3_SS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x80x32_F32E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x80x32_F32E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -32154,7 +32594,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x80x32_F32E4M3E4M3_RS_TN +struct MMA_64x80x32_F32E4M3E4M3_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -32177,6 +32617,7 @@ struct SM90_64x80x32_F32E4M3E4M3_RS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -32205,7 +32646,7 @@ struct SM90_64x80x32_F32E4M3E4M3_RS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x80x32_F32E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x80x32_F32E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -32218,7 +32659,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x96x32_F16E4M3E4M3_SS_TN +struct MMA_64x96x32_F16E4M3E4M3_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -32237,6 +32678,7 @@ struct SM90_64x96x32_F16E4M3E4M3_SS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -32259,7 +32701,7 @@ struct SM90_64x96x32_F16E4M3E4M3_SS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x96x32_F16E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x96x32_F16E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -32271,7 +32713,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x96x32_F16E4M3E4M3_RS_TN +struct MMA_64x96x32_F16E4M3E4M3_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -32290,6 +32732,7 @@ struct SM90_64x96x32_F16E4M3E4M3_RS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -32312,7 +32755,7 @@ struct SM90_64x96x32_F16E4M3E4M3_RS_TN 
"l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x96x32_F16E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x96x32_F16E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -32324,7 +32767,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x96x32_F32E4M3E4M3_SS_TN +struct MMA_64x96x32_F32E4M3E4M3_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -32349,6 +32792,7 @@ struct SM90_64x96x32_F32E4M3E4M3_SS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -32380,7 +32824,7 @@ struct SM90_64x96x32_F32E4M3E4M3_SS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x96x32_F32E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x96x32_F32E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -32392,7 +32836,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x96x32_F32E4M3E4M3_RS_TN +struct MMA_64x96x32_F32E4M3E4M3_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -32417,6 +32861,7 @@ struct SM90_64x96x32_F32E4M3E4M3_RS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -32448,7 +32893,7 @@ struct SM90_64x96x32_F32E4M3E4M3_RS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x96x32_F32E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x96x32_F32E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -32461,7 +32906,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x112x32_F16E4M3E4M3_SS_TN +struct MMA_64x112x32_F16E4M3E4M3_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -32481,6 +32926,7 @@ struct SM90_64x112x32_F16E4M3E4M3_SS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -32505,7 +32951,7 @@ struct SM90_64x112x32_F16E4M3E4M3_SS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x112x32_F16E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x112x32_F16E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -32519,7 +32965,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x112x32_F16E4M3E4M3_RS_TN +struct MMA_64x112x32_F16E4M3E4M3_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -32539,6 +32985,7 @@ struct SM90_64x112x32_F16E4M3E4M3_RS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + 
cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -32563,7 +33010,7 @@ struct SM90_64x112x32_F16E4M3E4M3_RS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x112x32_F16E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x112x32_F16E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -32577,7 +33024,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x112x32_F32E4M3E4M3_SS_TN +struct MMA_64x112x32_F32E4M3E4M3_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -32604,6 +33051,7 @@ struct SM90_64x112x32_F32E4M3E4M3_SS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -32638,7 +33086,7 @@ struct SM90_64x112x32_F32E4M3E4M3_SS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x112x32_F32E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x112x32_F32E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -32652,7 +33100,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x112x32_F32E4M3E4M3_RS_TN +struct MMA_64x112x32_F32E4M3E4M3_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -32679,6 +33127,7 @@ struct SM90_64x112x32_F32E4M3E4M3_RS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -32713,7 +33162,7 @@ struct SM90_64x112x32_F32E4M3E4M3_RS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x112x32_F32E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x112x32_F32E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -32726,7 +33175,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x128x32_F16E4M3E4M3_SS_TN +struct MMA_64x128x32_F16E4M3E4M3_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -32747,6 +33196,7 @@ struct SM90_64x128x32_F16E4M3E4M3_SS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -32772,7 +33222,7 @@ struct SM90_64x128x32_F16E4M3E4M3_SS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x128x32_F16E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x128x32_F16E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -32784,7 +33234,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x128x32_F16E4M3E4M3_RS_TN +struct MMA_64x128x32_F16E4M3E4M3_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; @@ 
-32805,6 +33255,7 @@ struct SM90_64x128x32_F16E4M3E4M3_RS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -32830,7 +33281,7 @@ struct SM90_64x128x32_F16E4M3E4M3_RS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x128x32_F16E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x128x32_F16E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -32842,7 +33293,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x128x32_F32E4M3E4M3_SS_TN +struct MMA_64x128x32_F32E4M3E4M3_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -32871,6 +33322,7 @@ struct SM90_64x128x32_F32E4M3E4M3_SS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -32908,7 +33360,7 @@ struct SM90_64x128x32_F32E4M3E4M3_SS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x128x32_F32E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x128x32_F32E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -32920,7 +33372,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x128x32_F32E4M3E4M3_RS_TN +struct MMA_64x128x32_F32E4M3E4M3_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -32949,6 +33401,7 @@ struct SM90_64x128x32_F32E4M3E4M3_RS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -32986,7 +33439,7 @@ struct SM90_64x128x32_F32E4M3E4M3_RS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x128x32_F32E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x128x32_F32E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -32999,7 +33452,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x144x32_F16E4M3E4M3_SS_TN +struct MMA_64x144x32_F16E4M3E4M3_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -33021,6 +33474,7 @@ struct SM90_64x144x32_F16E4M3E4M3_SS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -33048,7 +33502,7 @@ struct SM90_64x144x32_F16E4M3E4M3_SS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x144x32_F16E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x144x32_F16E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -33062,7 +33516,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = 
GMMA::ScaleIn::One > -struct SM90_64x144x32_F16E4M3E4M3_RS_TN +struct MMA_64x144x32_F16E4M3E4M3_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -33084,6 +33538,7 @@ struct SM90_64x144x32_F16E4M3E4M3_RS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -33111,7 +33566,7 @@ struct SM90_64x144x32_F16E4M3E4M3_RS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x144x32_F16E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x144x32_F16E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -33125,7 +33580,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x144x32_F32E4M3E4M3_SS_TN +struct MMA_64x144x32_F32E4M3E4M3_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -33156,6 +33611,7 @@ struct SM90_64x144x32_F32E4M3E4M3_SS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -33196,7 +33652,7 @@ struct SM90_64x144x32_F32E4M3E4M3_SS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x144x32_F32E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x144x32_F32E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -33210,7 +33666,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x144x32_F32E4M3E4M3_RS_TN +struct MMA_64x144x32_F32E4M3E4M3_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -33241,6 +33697,7 @@ struct SM90_64x144x32_F32E4M3E4M3_RS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -33281,7 +33738,7 @@ struct SM90_64x144x32_F32E4M3E4M3_RS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x144x32_F32E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x144x32_F32E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -33295,7 +33752,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x160x32_F16E4M3E4M3_SS_TN +struct MMA_64x160x32_F16E4M3E4M3_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -33318,6 +33775,7 @@ struct SM90_64x160x32_F16E4M3E4M3_SS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -33346,7 +33804,7 @@ struct SM90_64x160x32_F16E4M3E4M3_SS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x160x32_F16E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use 
MMA_64x160x32_F16E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -33360,7 +33818,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x160x32_F16E4M3E4M3_RS_TN +struct MMA_64x160x32_F16E4M3E4M3_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -33383,6 +33841,7 @@ struct SM90_64x160x32_F16E4M3E4M3_RS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -33411,7 +33870,7 @@ struct SM90_64x160x32_F16E4M3E4M3_RS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x160x32_F16E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x160x32_F16E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -33425,7 +33884,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x160x32_F32E4M3E4M3_SS_TN +struct MMA_64x160x32_F32E4M3E4M3_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -33458,6 +33917,7 @@ struct SM90_64x160x32_F32E4M3E4M3_SS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -33501,7 +33961,7 @@ struct SM90_64x160x32_F32E4M3E4M3_SS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x160x32_F32E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x160x32_F32E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -33515,7 +33975,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x160x32_F32E4M3E4M3_RS_TN +struct MMA_64x160x32_F32E4M3E4M3_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -33548,6 +34008,7 @@ struct SM90_64x160x32_F32E4M3E4M3_RS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -33591,7 +34052,7 @@ struct SM90_64x160x32_F32E4M3E4M3_RS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x160x32_F32E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x160x32_F32E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -33605,7 +34066,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x176x32_F16E4M3E4M3_SS_TN +struct MMA_64x176x32_F16E4M3E4M3_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -33629,6 +34090,7 @@ struct SM90_64x176x32_F16E4M3E4M3_SS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -33659,7 +34121,7 @@ struct SM90_64x176x32_F16E4M3E4M3_SS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - 
CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x176x32_F16E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x176x32_F16E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -33673,7 +34135,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x176x32_F16E4M3E4M3_RS_TN +struct MMA_64x176x32_F16E4M3E4M3_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -33697,6 +34159,7 @@ struct SM90_64x176x32_F16E4M3E4M3_RS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -33727,7 +34190,7 @@ struct SM90_64x176x32_F16E4M3E4M3_RS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x176x32_F16E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x176x32_F16E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -33741,7 +34204,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x176x32_F32E4M3E4M3_SS_TN +struct MMA_64x176x32_F32E4M3E4M3_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -33776,6 +34239,7 @@ struct SM90_64x176x32_F32E4M3E4M3_SS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -33822,7 +34286,7 @@ struct SM90_64x176x32_F32E4M3E4M3_SS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x176x32_F32E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x176x32_F32E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -33836,7 +34300,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x176x32_F32E4M3E4M3_RS_TN +struct MMA_64x176x32_F32E4M3E4M3_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -33871,6 +34335,7 @@ struct SM90_64x176x32_F32E4M3E4M3_RS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -33917,7 +34382,7 @@ struct SM90_64x176x32_F32E4M3E4M3_RS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x176x32_F32E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x176x32_F32E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -33930,7 +34395,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x192x32_F16E4M3E4M3_SS_TN +struct MMA_64x192x32_F16E4M3E4M3_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -33955,6 +34420,7 @@ struct SM90_64x192x32_F16E4M3E4M3_SS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred 
p;\n" @@ -33986,7 +34452,7 @@ struct SM90_64x192x32_F16E4M3E4M3_SS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x192x32_F16E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x192x32_F16E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -33998,7 +34464,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x192x32_F16E4M3E4M3_RS_TN +struct MMA_64x192x32_F16E4M3E4M3_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -34023,6 +34489,7 @@ struct SM90_64x192x32_F16E4M3E4M3_RS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -34054,7 +34521,7 @@ struct SM90_64x192x32_F16E4M3E4M3_RS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x192x32_F16E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x192x32_F16E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -34066,7 +34533,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x192x32_F32E4M3E4M3_SS_TN +struct MMA_64x192x32_F32E4M3E4M3_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -34103,6 +34570,7 @@ struct SM90_64x192x32_F32E4M3E4M3_SS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -34152,7 +34620,7 @@ struct SM90_64x192x32_F32E4M3E4M3_SS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x192x32_F32E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x192x32_F32E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -34164,7 +34632,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x192x32_F32E4M3E4M3_RS_TN +struct MMA_64x192x32_F32E4M3E4M3_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -34201,6 +34669,7 @@ struct SM90_64x192x32_F32E4M3E4M3_RS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -34250,7 +34719,7 @@ struct SM90_64x192x32_F32E4M3E4M3_RS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x192x32_F32E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x192x32_F32E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -34263,7 +34732,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x208x32_F16E4M3E4M3_SS_TN +struct MMA_64x208x32_F16E4M3E4M3_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -34289,6 +34758,7 @@ struct SM90_64x208x32_F16E4M3E4M3_SS_TN GMMA::ScaleOut const scale_D = 
GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -34322,7 +34792,7 @@ struct SM90_64x208x32_F16E4M3E4M3_SS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x208x32_F16E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x208x32_F16E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -34336,7 +34806,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x208x32_F16E4M3E4M3_RS_TN +struct MMA_64x208x32_F16E4M3E4M3_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -34362,6 +34832,7 @@ struct SM90_64x208x32_F16E4M3E4M3_RS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -34395,7 +34866,7 @@ struct SM90_64x208x32_F16E4M3E4M3_RS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x208x32_F16E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x208x32_F16E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -34409,7 +34880,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x208x32_F32E4M3E4M3_SS_TN +struct MMA_64x208x32_F32E4M3E4M3_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -34448,6 +34919,7 @@ struct SM90_64x208x32_F32E4M3E4M3_SS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -34500,7 +34972,7 @@ struct SM90_64x208x32_F32E4M3E4M3_SS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x208x32_F32E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x208x32_F32E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -34514,7 +34986,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x208x32_F32E4M3E4M3_RS_TN +struct MMA_64x208x32_F32E4M3E4M3_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -34553,6 +35025,7 @@ struct SM90_64x208x32_F32E4M3E4M3_RS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -34605,7 +35078,7 @@ struct SM90_64x208x32_F32E4M3E4M3_RS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x208x32_F32E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x208x32_F32E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -34619,7 +35092,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x224x32_F16E4M3E4M3_SS_TN +struct 
MMA_64x224x32_F16E4M3E4M3_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -34646,6 +35119,7 @@ struct SM90_64x224x32_F16E4M3E4M3_SS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -34680,7 +35154,7 @@ struct SM90_64x224x32_F16E4M3E4M3_SS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x224x32_F16E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x224x32_F16E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -34694,7 +35168,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x224x32_F16E4M3E4M3_RS_TN +struct MMA_64x224x32_F16E4M3E4M3_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -34721,6 +35195,7 @@ struct SM90_64x224x32_F16E4M3E4M3_RS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -34755,7 +35230,7 @@ struct SM90_64x224x32_F16E4M3E4M3_RS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x224x32_F16E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x224x32_F16E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -34769,7 +35244,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x224x32_F32E4M3E4M3_SS_TN +struct MMA_64x224x32_F32E4M3E4M3_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -34810,6 +35285,7 @@ struct SM90_64x224x32_F32E4M3E4M3_SS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -34865,7 +35341,7 @@ struct SM90_64x224x32_F32E4M3E4M3_SS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x224x32_F32E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x224x32_F32E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -34879,7 +35355,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x224x32_F32E4M3E4M3_RS_TN +struct MMA_64x224x32_F32E4M3E4M3_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -34920,6 +35396,7 @@ struct SM90_64x224x32_F32E4M3E4M3_RS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -34975,7 +35452,7 @@ struct SM90_64x224x32_F32E4M3E4M3_RS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x224x32_F32E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x224x32_F32E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -34989,7 
+35466,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x240x32_F16E4M3E4M3_SS_TN +struct MMA_64x240x32_F16E4M3E4M3_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -35017,6 +35494,7 @@ struct SM90_64x240x32_F16E4M3E4M3_SS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -35053,7 +35531,7 @@ struct SM90_64x240x32_F16E4M3E4M3_SS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x240x32_F16E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x240x32_F16E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -35067,7 +35545,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x240x32_F16E4M3E4M3_RS_TN +struct MMA_64x240x32_F16E4M3E4M3_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -35095,6 +35573,7 @@ struct SM90_64x240x32_F16E4M3E4M3_RS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -35131,7 +35610,7 @@ struct SM90_64x240x32_F16E4M3E4M3_RS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x240x32_F16E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x240x32_F16E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -35145,7 +35624,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x240x32_F32E4M3E4M3_SS_TN +struct MMA_64x240x32_F32E4M3E4M3_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -35188,6 +35667,7 @@ struct SM90_64x240x32_F32E4M3E4M3_SS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -35246,7 +35726,7 @@ struct SM90_64x240x32_F32E4M3E4M3_SS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x240x32_F32E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x240x32_F32E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -35260,7 +35740,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x240x32_F32E4M3E4M3_RS_TN +struct MMA_64x240x32_F32E4M3E4M3_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -35303,6 +35783,7 @@ struct SM90_64x240x32_F32E4M3E4M3_RS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -35361,7 +35842,7 @@ struct SM90_64x240x32_F32E4M3E4M3_RS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x240x32_F32E4M3E4M3_RS_TN without 
CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x240x32_F32E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -35374,7 +35855,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x256x32_F16E4M3E4M3_SS_TN +struct MMA_64x256x32_F16E4M3E4M3_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -35403,6 +35884,7 @@ struct SM90_64x256x32_F16E4M3E4M3_SS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -35440,7 +35922,7 @@ struct SM90_64x256x32_F16E4M3E4M3_SS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x256x32_F16E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x256x32_F16E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -35452,7 +35934,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x256x32_F16E4M3E4M3_RS_TN +struct MMA_64x256x32_F16E4M3E4M3_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -35481,6 +35963,7 @@ struct SM90_64x256x32_F16E4M3E4M3_RS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -35518,7 +36001,7 @@ struct SM90_64x256x32_F16E4M3E4M3_RS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x256x32_F16E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x256x32_F16E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -35530,7 +36013,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x256x32_F32E4M3E4M3_SS_TN +struct MMA_64x256x32_F32E4M3E4M3_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -35575,6 +36058,7 @@ struct SM90_64x256x32_F32E4M3E4M3_SS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -35636,7 +36120,7 @@ struct SM90_64x256x32_F32E4M3E4M3_SS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x256x32_F32E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x256x32_F32E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -35648,7 +36132,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x256x32_F32E4M3E4M3_RS_TN +struct MMA_64x256x32_F32E4M3E4M3_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -35693,6 +36177,7 @@ struct SM90_64x256x32_F32E4M3E4M3_RS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -35754,7 +36239,7 @@ struct SM90_64x256x32_F32E4M3E4M3_RS_TN "l"(desc_b), 
"r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x256x32_F32E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x256x32_F32E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -35766,7 +36251,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x8x32_F16E4M3E5M2_SS_TN +struct MMA_64x8x32_F16E4M3E5M2_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -35780,6 +36265,7 @@ struct SM90_64x8x32_F16E4M3E5M2_SS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -35795,7 +36281,7 @@ struct SM90_64x8x32_F16E4M3E5M2_SS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x8x32_F16E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x8x32_F16E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -35807,7 +36293,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x8x32_F16E4M3E5M2_RS_TN +struct MMA_64x8x32_F16E4M3E5M2_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -35821,6 +36307,7 @@ struct SM90_64x8x32_F16E4M3E5M2_RS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -35836,7 +36323,7 @@ struct SM90_64x8x32_F16E4M3E5M2_RS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x8x32_F16E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x8x32_F16E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -35848,7 +36335,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x8x32_F32E4M3E5M2_SS_TN +struct MMA_64x8x32_F32E4M3E5M2_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -35862,6 +36349,7 @@ struct SM90_64x8x32_F32E4M3E5M2_SS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -35877,7 +36365,7 @@ struct SM90_64x8x32_F32E4M3E5M2_SS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x8x32_F32E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x8x32_F32E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -35889,7 +36377,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x8x32_F32E4M3E5M2_RS_TN +struct MMA_64x8x32_F32E4M3E5M2_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -35903,6 +36391,7 @@ struct SM90_64x8x32_F32E4M3E5M2_RS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); 
asm volatile( "{\n" ".reg .pred p;\n" @@ -35918,7 +36407,7 @@ struct SM90_64x8x32_F32E4M3E5M2_RS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x8x32_F32E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x8x32_F32E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -35930,7 +36419,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x16x32_F16E4M3E5M2_SS_TN +struct MMA_64x16x32_F16E4M3E5M2_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -35944,6 +36433,7 @@ struct SM90_64x16x32_F16E4M3E5M2_SS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -35959,7 +36449,7 @@ struct SM90_64x16x32_F16E4M3E5M2_SS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x16x32_F16E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x16x32_F16E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -35971,7 +36461,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x16x32_F16E4M3E5M2_RS_TN +struct MMA_64x16x32_F16E4M3E5M2_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -35985,6 +36475,7 @@ struct SM90_64x16x32_F16E4M3E5M2_RS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -36000,7 +36491,7 @@ struct SM90_64x16x32_F16E4M3E5M2_RS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x16x32_F16E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x16x32_F16E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -36012,7 +36503,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x16x32_F32E4M3E5M2_SS_TN +struct MMA_64x16x32_F32E4M3E5M2_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -36027,6 +36518,7 @@ struct SM90_64x16x32_F32E4M3E5M2_SS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -36043,7 +36535,7 @@ struct SM90_64x16x32_F32E4M3E5M2_SS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x16x32_F32E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x16x32_F32E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -36055,7 +36547,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x16x32_F32E4M3E5M2_RS_TN +struct MMA_64x16x32_F32E4M3E5M2_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -36070,6 +36562,7 @@ struct SM90_64x16x32_F32E4M3E5M2_RS_TN GMMA::ScaleOut const scale_D = 
GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -36086,7 +36579,7 @@ struct SM90_64x16x32_F32E4M3E5M2_RS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x16x32_F32E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x16x32_F32E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -36098,7 +36591,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x32x32_F16E4M3E5M2_SS_TN +struct MMA_64x32x32_F16E4M3E5M2_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -36113,6 +36606,7 @@ struct SM90_64x32x32_F16E4M3E5M2_SS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -36129,7 +36623,7 @@ struct SM90_64x32x32_F16E4M3E5M2_SS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x32x32_F16E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x32x32_F16E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -36141,7 +36635,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x32x32_F16E4M3E5M2_RS_TN +struct MMA_64x32x32_F16E4M3E5M2_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -36156,6 +36650,7 @@ struct SM90_64x32x32_F16E4M3E5M2_RS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -36172,7 +36667,7 @@ struct SM90_64x32x32_F16E4M3E5M2_RS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x32x32_F16E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x32x32_F16E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -36184,7 +36679,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x32x32_F32E4M3E5M2_SS_TN +struct MMA_64x32x32_F32E4M3E5M2_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -36201,6 +36696,7 @@ struct SM90_64x32x32_F32E4M3E5M2_SS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -36220,7 +36716,7 @@ struct SM90_64x32x32_F32E4M3E5M2_SS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x32x32_F32E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x32x32_F32E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -36232,7 +36728,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x32x32_F32E4M3E5M2_RS_TN +struct MMA_64x32x32_F32E4M3E5M2_RS_TN { using DRegisters = 
void; using ARegisters = uint32_t[4]; @@ -36249,6 +36745,7 @@ struct SM90_64x32x32_F32E4M3E5M2_RS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -36268,7 +36765,7 @@ struct SM90_64x32x32_F32E4M3E5M2_RS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x32x32_F32E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x32x32_F32E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -36281,7 +36778,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x48x32_F16E4M3E5M2_SS_TN +struct MMA_64x48x32_F16E4M3E5M2_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -36297,6 +36794,7 @@ struct SM90_64x48x32_F16E4M3E5M2_SS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -36315,7 +36813,7 @@ struct SM90_64x48x32_F16E4M3E5M2_SS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x48x32_F16E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x48x32_F16E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -36329,7 +36827,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x48x32_F16E4M3E5M2_RS_TN +struct MMA_64x48x32_F16E4M3E5M2_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -36345,6 +36843,7 @@ struct SM90_64x48x32_F16E4M3E5M2_RS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -36363,7 +36862,7 @@ struct SM90_64x48x32_F16E4M3E5M2_RS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x48x32_F16E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x48x32_F16E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -36377,7 +36876,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x48x32_F32E4M3E5M2_SS_TN +struct MMA_64x48x32_F32E4M3E5M2_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -36396,6 +36895,7 @@ struct SM90_64x48x32_F32E4M3E5M2_SS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -36418,7 +36918,7 @@ struct SM90_64x48x32_F32E4M3E5M2_SS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x48x32_F32E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x48x32_F32E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -36432,7 +36932,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, 
GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x48x32_F32E4M3E5M2_RS_TN +struct MMA_64x48x32_F32E4M3E5M2_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -36451,6 +36951,7 @@ struct SM90_64x48x32_F32E4M3E5M2_RS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -36473,7 +36974,7 @@ struct SM90_64x48x32_F32E4M3E5M2_RS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x48x32_F32E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x48x32_F32E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -36486,7 +36987,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x64x32_F16E4M3E5M2_SS_TN +struct MMA_64x64x32_F16E4M3E5M2_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -36503,6 +37004,7 @@ struct SM90_64x64x32_F16E4M3E5M2_SS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -36522,7 +37024,7 @@ struct SM90_64x64x32_F16E4M3E5M2_SS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x64x32_F16E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x64x32_F16E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -36534,7 +37036,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x64x32_F16E4M3E5M2_RS_TN +struct MMA_64x64x32_F16E4M3E5M2_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -36551,6 +37053,7 @@ struct SM90_64x64x32_F16E4M3E5M2_RS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -36570,7 +37073,7 @@ struct SM90_64x64x32_F16E4M3E5M2_RS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x64x32_F16E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x64x32_F16E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -36582,7 +37085,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x64x32_F32E4M3E5M2_SS_TN +struct MMA_64x64x32_F32E4M3E5M2_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -36603,6 +37106,7 @@ struct SM90_64x64x32_F32E4M3E5M2_SS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -36628,7 +37132,7 @@ struct SM90_64x64x32_F32E4M3E5M2_SS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x64x32_F32E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use 
MMA_64x64x32_F32E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -36640,7 +37144,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x64x32_F32E4M3E5M2_RS_TN +struct MMA_64x64x32_F32E4M3E5M2_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -36661,6 +37165,7 @@ struct SM90_64x64x32_F32E4M3E5M2_RS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -36686,7 +37191,7 @@ struct SM90_64x64x32_F32E4M3E5M2_RS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x64x32_F32E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x64x32_F32E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -36699,7 +37204,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x80x32_F16E4M3E5M2_SS_TN +struct MMA_64x80x32_F16E4M3E5M2_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -36717,6 +37222,7 @@ struct SM90_64x80x32_F16E4M3E5M2_SS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -36738,7 +37244,7 @@ struct SM90_64x80x32_F16E4M3E5M2_SS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x80x32_F16E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x80x32_F16E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -36752,7 +37258,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x80x32_F16E4M3E5M2_RS_TN +struct MMA_64x80x32_F16E4M3E5M2_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -36770,6 +37276,7 @@ struct SM90_64x80x32_F16E4M3E5M2_RS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -36791,7 +37298,7 @@ struct SM90_64x80x32_F16E4M3E5M2_RS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x80x32_F16E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x80x32_F16E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -36805,7 +37312,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x80x32_F32E4M3E5M2_SS_TN +struct MMA_64x80x32_F32E4M3E5M2_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -36828,6 +37335,7 @@ struct SM90_64x80x32_F32E4M3E5M2_SS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -36856,7 +37364,7 @@ struct SM90_64x80x32_F32E4M3E5M2_SS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - 
CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x80x32_F32E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x80x32_F32E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -36870,7 +37378,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x80x32_F32E4M3E5M2_RS_TN +struct MMA_64x80x32_F32E4M3E5M2_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -36893,6 +37401,7 @@ struct SM90_64x80x32_F32E4M3E5M2_RS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -36921,7 +37430,7 @@ struct SM90_64x80x32_F32E4M3E5M2_RS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x80x32_F32E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x80x32_F32E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -36934,7 +37443,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x96x32_F16E4M3E5M2_SS_TN +struct MMA_64x96x32_F16E4M3E5M2_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -36953,6 +37462,7 @@ struct SM90_64x96x32_F16E4M3E5M2_SS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -36975,7 +37485,7 @@ struct SM90_64x96x32_F16E4M3E5M2_SS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x96x32_F16E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x96x32_F16E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -36987,7 +37497,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x96x32_F16E4M3E5M2_RS_TN +struct MMA_64x96x32_F16E4M3E5M2_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -37006,6 +37516,7 @@ struct SM90_64x96x32_F16E4M3E5M2_RS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -37028,7 +37539,7 @@ struct SM90_64x96x32_F16E4M3E5M2_RS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x96x32_F16E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x96x32_F16E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -37040,7 +37551,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x96x32_F32E4M3E5M2_SS_TN +struct MMA_64x96x32_F32E4M3E5M2_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -37065,6 +37576,7 @@ struct SM90_64x96x32_F32E4M3E5M2_SS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -37096,7 +37608,7 
@@ struct SM90_64x96x32_F32E4M3E5M2_SS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x96x32_F32E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x96x32_F32E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -37108,7 +37620,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x96x32_F32E4M3E5M2_RS_TN +struct MMA_64x96x32_F32E4M3E5M2_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -37133,6 +37645,7 @@ struct SM90_64x96x32_F32E4M3E5M2_RS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -37164,7 +37677,7 @@ struct SM90_64x96x32_F32E4M3E5M2_RS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x96x32_F32E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x96x32_F32E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -37177,7 +37690,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x112x32_F16E4M3E5M2_SS_TN +struct MMA_64x112x32_F16E4M3E5M2_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -37197,6 +37710,7 @@ struct SM90_64x112x32_F16E4M3E5M2_SS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -37221,7 +37735,7 @@ struct SM90_64x112x32_F16E4M3E5M2_SS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x112x32_F16E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x112x32_F16E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -37235,7 +37749,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x112x32_F16E4M3E5M2_RS_TN +struct MMA_64x112x32_F16E4M3E5M2_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -37255,6 +37769,7 @@ struct SM90_64x112x32_F16E4M3E5M2_RS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -37279,7 +37794,7 @@ struct SM90_64x112x32_F16E4M3E5M2_RS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x112x32_F16E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x112x32_F16E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -37293,7 +37808,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x112x32_F32E4M3E5M2_SS_TN +struct MMA_64x112x32_F32E4M3E5M2_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -37320,6 +37835,7 @@ struct SM90_64x112x32_F32E4M3E5M2_SS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if 
defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -37354,7 +37870,7 @@ struct SM90_64x112x32_F32E4M3E5M2_SS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x112x32_F32E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x112x32_F32E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -37368,7 +37884,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x112x32_F32E4M3E5M2_RS_TN +struct MMA_64x112x32_F32E4M3E5M2_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -37395,6 +37911,7 @@ struct SM90_64x112x32_F32E4M3E5M2_RS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -37429,7 +37946,7 @@ struct SM90_64x112x32_F32E4M3E5M2_RS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x112x32_F32E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x112x32_F32E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -37442,7 +37959,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x128x32_F16E4M3E5M2_SS_TN +struct MMA_64x128x32_F16E4M3E5M2_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -37463,6 +37980,7 @@ struct SM90_64x128x32_F16E4M3E5M2_SS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -37488,7 +38006,7 @@ struct SM90_64x128x32_F16E4M3E5M2_SS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x128x32_F16E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x128x32_F16E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -37500,7 +38018,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x128x32_F16E4M3E5M2_RS_TN +struct MMA_64x128x32_F16E4M3E5M2_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -37521,6 +38039,7 @@ struct SM90_64x128x32_F16E4M3E5M2_RS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -37546,7 +38065,7 @@ struct SM90_64x128x32_F16E4M3E5M2_RS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x128x32_F16E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x128x32_F16E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -37558,7 +38077,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x128x32_F32E4M3E5M2_SS_TN +struct MMA_64x128x32_F32E4M3E5M2_SS_TN { using DRegisters = 
void; using ARegisters = uint64_t[1]; @@ -37587,6 +38106,7 @@ struct SM90_64x128x32_F32E4M3E5M2_SS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -37624,7 +38144,7 @@ struct SM90_64x128x32_F32E4M3E5M2_SS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x128x32_F32E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x128x32_F32E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -37636,7 +38156,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x128x32_F32E4M3E5M2_RS_TN +struct MMA_64x128x32_F32E4M3E5M2_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -37665,6 +38185,7 @@ struct SM90_64x128x32_F32E4M3E5M2_RS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -37702,7 +38223,7 @@ struct SM90_64x128x32_F32E4M3E5M2_RS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x128x32_F32E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x128x32_F32E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -37715,7 +38236,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x144x32_F16E4M3E5M2_SS_TN +struct MMA_64x144x32_F16E4M3E5M2_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -37737,6 +38258,7 @@ struct SM90_64x144x32_F16E4M3E5M2_SS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -37764,7 +38286,7 @@ struct SM90_64x144x32_F16E4M3E5M2_SS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x144x32_F16E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x144x32_F16E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -37778,7 +38300,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x144x32_F16E4M3E5M2_RS_TN +struct MMA_64x144x32_F16E4M3E5M2_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -37800,6 +38322,7 @@ struct SM90_64x144x32_F16E4M3E5M2_RS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -37827,7 +38350,7 @@ struct SM90_64x144x32_F16E4M3E5M2_RS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x144x32_F16E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x144x32_F16E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -37841,7 +38364,7 @@ template < GMMA::ScaleIn scaleA = 
GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x144x32_F32E4M3E5M2_SS_TN +struct MMA_64x144x32_F32E4M3E5M2_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -37872,6 +38395,7 @@ struct SM90_64x144x32_F32E4M3E5M2_SS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -37912,7 +38436,7 @@ struct SM90_64x144x32_F32E4M3E5M2_SS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x144x32_F32E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x144x32_F32E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -37926,7 +38450,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x144x32_F32E4M3E5M2_RS_TN +struct MMA_64x144x32_F32E4M3E5M2_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -37957,6 +38481,7 @@ struct SM90_64x144x32_F32E4M3E5M2_RS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -37997,7 +38522,7 @@ struct SM90_64x144x32_F32E4M3E5M2_RS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x144x32_F32E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x144x32_F32E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -38011,7 +38536,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x160x32_F16E4M3E5M2_SS_TN +struct MMA_64x160x32_F16E4M3E5M2_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -38034,6 +38559,7 @@ struct SM90_64x160x32_F16E4M3E5M2_SS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -38062,7 +38588,7 @@ struct SM90_64x160x32_F16E4M3E5M2_SS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x160x32_F16E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x160x32_F16E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -38076,7 +38602,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x160x32_F16E4M3E5M2_RS_TN +struct MMA_64x160x32_F16E4M3E5M2_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -38099,6 +38625,7 @@ struct SM90_64x160x32_F16E4M3E5M2_RS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -38127,7 +38654,7 @@ struct SM90_64x160x32_F16E4M3E5M2_RS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x160x32_F16E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + 
CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x160x32_F16E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -38141,7 +38668,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x160x32_F32E4M3E5M2_SS_TN +struct MMA_64x160x32_F32E4M3E5M2_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -38174,6 +38701,7 @@ struct SM90_64x160x32_F32E4M3E5M2_SS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -38217,7 +38745,7 @@ struct SM90_64x160x32_F32E4M3E5M2_SS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x160x32_F32E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x160x32_F32E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -38231,7 +38759,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x160x32_F32E4M3E5M2_RS_TN +struct MMA_64x160x32_F32E4M3E5M2_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -38264,6 +38792,7 @@ struct SM90_64x160x32_F32E4M3E5M2_RS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -38307,7 +38836,7 @@ struct SM90_64x160x32_F32E4M3E5M2_RS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x160x32_F32E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x160x32_F32E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -38321,7 +38850,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x176x32_F16E4M3E5M2_SS_TN +struct MMA_64x176x32_F16E4M3E5M2_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -38345,6 +38874,7 @@ struct SM90_64x176x32_F16E4M3E5M2_SS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -38375,7 +38905,7 @@ struct SM90_64x176x32_F16E4M3E5M2_SS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x176x32_F16E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x176x32_F16E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -38389,7 +38919,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x176x32_F16E4M3E5M2_RS_TN +struct MMA_64x176x32_F16E4M3E5M2_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -38413,6 +38943,7 @@ struct SM90_64x176x32_F16E4M3E5M2_RS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -38443,7 +38974,7 @@ struct SM90_64x176x32_F16E4M3E5M2_RS_TN "l"(desc_b), "r"(int32_t(scale_D)), 
"n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x176x32_F16E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x176x32_F16E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -38457,7 +38988,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x176x32_F32E4M3E5M2_SS_TN +struct MMA_64x176x32_F32E4M3E5M2_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -38492,6 +39023,7 @@ struct SM90_64x176x32_F32E4M3E5M2_SS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -38538,7 +39070,7 @@ struct SM90_64x176x32_F32E4M3E5M2_SS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x176x32_F32E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x176x32_F32E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -38552,7 +39084,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x176x32_F32E4M3E5M2_RS_TN +struct MMA_64x176x32_F32E4M3E5M2_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -38587,6 +39119,7 @@ struct SM90_64x176x32_F32E4M3E5M2_RS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -38633,7 +39166,7 @@ struct SM90_64x176x32_F32E4M3E5M2_RS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x176x32_F32E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x176x32_F32E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -38646,7 +39179,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x192x32_F16E4M3E5M2_SS_TN +struct MMA_64x192x32_F16E4M3E5M2_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -38671,6 +39204,7 @@ struct SM90_64x192x32_F16E4M3E5M2_SS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -38702,7 +39236,7 @@ struct SM90_64x192x32_F16E4M3E5M2_SS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x192x32_F16E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x192x32_F16E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -38714,7 +39248,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x192x32_F16E4M3E5M2_RS_TN +struct MMA_64x192x32_F16E4M3E5M2_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -38739,6 +39273,7 @@ struct SM90_64x192x32_F16E4M3E5M2_RS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + 
cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -38770,7 +39305,7 @@ struct SM90_64x192x32_F16E4M3E5M2_RS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x192x32_F16E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x192x32_F16E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -38782,7 +39317,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x192x32_F32E4M3E5M2_SS_TN +struct MMA_64x192x32_F32E4M3E5M2_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -38819,6 +39354,7 @@ struct SM90_64x192x32_F32E4M3E5M2_SS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -38868,7 +39404,7 @@ struct SM90_64x192x32_F32E4M3E5M2_SS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x192x32_F32E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x192x32_F32E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -38880,7 +39416,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x192x32_F32E4M3E5M2_RS_TN +struct MMA_64x192x32_F32E4M3E5M2_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -38917,6 +39453,7 @@ struct SM90_64x192x32_F32E4M3E5M2_RS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -38966,7 +39503,7 @@ struct SM90_64x192x32_F32E4M3E5M2_RS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x192x32_F32E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x192x32_F32E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -38979,7 +39516,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x208x32_F16E4M3E5M2_SS_TN +struct MMA_64x208x32_F16E4M3E5M2_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -39005,6 +39542,7 @@ struct SM90_64x208x32_F16E4M3E5M2_SS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -39038,7 +39576,7 @@ struct SM90_64x208x32_F16E4M3E5M2_SS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x208x32_F16E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x208x32_F16E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -39052,7 +39590,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x208x32_F16E4M3E5M2_RS_TN +struct MMA_64x208x32_F16E4M3E5M2_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; @@ 
-39078,6 +39616,7 @@ struct SM90_64x208x32_F16E4M3E5M2_RS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -39111,7 +39650,7 @@ struct SM90_64x208x32_F16E4M3E5M2_RS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x208x32_F16E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x208x32_F16E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -39125,7 +39664,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x208x32_F32E4M3E5M2_SS_TN +struct MMA_64x208x32_F32E4M3E5M2_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -39164,6 +39703,7 @@ struct SM90_64x208x32_F32E4M3E5M2_SS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -39216,7 +39756,7 @@ struct SM90_64x208x32_F32E4M3E5M2_SS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x208x32_F32E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x208x32_F32E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -39230,7 +39770,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x208x32_F32E4M3E5M2_RS_TN +struct MMA_64x208x32_F32E4M3E5M2_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -39269,6 +39809,7 @@ struct SM90_64x208x32_F32E4M3E5M2_RS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -39321,7 +39862,7 @@ struct SM90_64x208x32_F32E4M3E5M2_RS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x208x32_F32E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x208x32_F32E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -39335,7 +39876,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x224x32_F16E4M3E5M2_SS_TN +struct MMA_64x224x32_F16E4M3E5M2_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -39362,6 +39903,7 @@ struct SM90_64x224x32_F16E4M3E5M2_SS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -39396,7 +39938,7 @@ struct SM90_64x224x32_F16E4M3E5M2_SS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x224x32_F16E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x224x32_F16E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -39410,7 +39952,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = 
GMMA::ScaleIn::One > -struct SM90_64x224x32_F16E4M3E5M2_RS_TN +struct MMA_64x224x32_F16E4M3E5M2_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -39437,6 +39979,7 @@ struct SM90_64x224x32_F16E4M3E5M2_RS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -39471,7 +40014,7 @@ struct SM90_64x224x32_F16E4M3E5M2_RS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x224x32_F16E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x224x32_F16E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -39485,7 +40028,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x224x32_F32E4M3E5M2_SS_TN +struct MMA_64x224x32_F32E4M3E5M2_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -39526,6 +40069,7 @@ struct SM90_64x224x32_F32E4M3E5M2_SS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -39581,7 +40125,7 @@ struct SM90_64x224x32_F32E4M3E5M2_SS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x224x32_F32E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x224x32_F32E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -39595,7 +40139,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x224x32_F32E4M3E5M2_RS_TN +struct MMA_64x224x32_F32E4M3E5M2_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -39636,6 +40180,7 @@ struct SM90_64x224x32_F32E4M3E5M2_RS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -39691,7 +40236,7 @@ struct SM90_64x224x32_F32E4M3E5M2_RS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x224x32_F32E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x224x32_F32E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -39705,7 +40250,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x240x32_F16E4M3E5M2_SS_TN +struct MMA_64x240x32_F16E4M3E5M2_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -39733,6 +40278,7 @@ struct SM90_64x240x32_F16E4M3E5M2_SS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -39769,7 +40315,7 @@ struct SM90_64x240x32_F16E4M3E5M2_SS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x240x32_F16E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use 
MMA_64x240x32_F16E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -39783,7 +40329,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x240x32_F16E4M3E5M2_RS_TN +struct MMA_64x240x32_F16E4M3E5M2_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -39811,6 +40357,7 @@ struct SM90_64x240x32_F16E4M3E5M2_RS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -39847,7 +40394,7 @@ struct SM90_64x240x32_F16E4M3E5M2_RS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x240x32_F16E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x240x32_F16E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -39861,7 +40408,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x240x32_F32E4M3E5M2_SS_TN +struct MMA_64x240x32_F32E4M3E5M2_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -39904,6 +40451,7 @@ struct SM90_64x240x32_F32E4M3E5M2_SS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -39962,7 +40510,7 @@ struct SM90_64x240x32_F32E4M3E5M2_SS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x240x32_F32E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x240x32_F32E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -39976,7 +40524,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x240x32_F32E4M3E5M2_RS_TN +struct MMA_64x240x32_F32E4M3E5M2_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -40019,6 +40567,7 @@ struct SM90_64x240x32_F32E4M3E5M2_RS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -40077,7 +40626,7 @@ struct SM90_64x240x32_F32E4M3E5M2_RS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x240x32_F32E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x240x32_F32E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -40090,7 +40639,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x256x32_F16E4M3E5M2_SS_TN +struct MMA_64x256x32_F16E4M3E5M2_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -40119,6 +40668,7 @@ struct SM90_64x256x32_F16E4M3E5M2_SS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -40156,7 +40706,7 @@ struct SM90_64x256x32_F16E4M3E5M2_SS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - 
CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x256x32_F16E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x256x32_F16E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -40168,7 +40718,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x256x32_F16E4M3E5M2_RS_TN +struct MMA_64x256x32_F16E4M3E5M2_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -40197,6 +40747,7 @@ struct SM90_64x256x32_F16E4M3E5M2_RS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -40234,7 +40785,7 @@ struct SM90_64x256x32_F16E4M3E5M2_RS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x256x32_F16E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x256x32_F16E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -40246,7 +40797,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x256x32_F32E4M3E5M2_SS_TN +struct MMA_64x256x32_F32E4M3E5M2_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -40291,6 +40842,7 @@ struct SM90_64x256x32_F32E4M3E5M2_SS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -40352,7 +40904,7 @@ struct SM90_64x256x32_F32E4M3E5M2_SS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x256x32_F32E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x256x32_F32E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -40364,7 +40916,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x256x32_F32E4M3E5M2_RS_TN +struct MMA_64x256x32_F32E4M3E5M2_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -40409,6 +40961,7 @@ struct SM90_64x256x32_F32E4M3E5M2_RS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -40470,7 +41023,7 @@ struct SM90_64x256x32_F32E4M3E5M2_RS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x256x32_F32E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x256x32_F32E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -40482,7 +41035,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x8x32_F16E5M2E4M3_SS_TN +struct MMA_64x8x32_F16E5M2E4M3_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -40496,6 +41049,7 @@ struct SM90_64x8x32_F16E5M2E4M3_SS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ 
-40511,7 +41065,7 @@ struct SM90_64x8x32_F16E5M2E4M3_SS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x8x32_F16E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x8x32_F16E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -40523,7 +41077,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x8x32_F16E5M2E4M3_RS_TN +struct MMA_64x8x32_F16E5M2E4M3_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -40537,6 +41091,7 @@ struct SM90_64x8x32_F16E5M2E4M3_RS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -40552,7 +41107,7 @@ struct SM90_64x8x32_F16E5M2E4M3_RS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x8x32_F16E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x8x32_F16E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -40564,7 +41119,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x8x32_F32E5M2E4M3_SS_TN +struct MMA_64x8x32_F32E5M2E4M3_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -40578,6 +41133,7 @@ struct SM90_64x8x32_F32E5M2E4M3_SS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -40593,7 +41149,7 @@ struct SM90_64x8x32_F32E5M2E4M3_SS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x8x32_F32E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x8x32_F32E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -40605,7 +41161,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x8x32_F32E5M2E4M3_RS_TN +struct MMA_64x8x32_F32E5M2E4M3_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -40619,6 +41175,7 @@ struct SM90_64x8x32_F32E5M2E4M3_RS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -40634,7 +41191,7 @@ struct SM90_64x8x32_F32E5M2E4M3_RS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x8x32_F32E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x8x32_F32E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -40646,7 +41203,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x16x32_F16E5M2E4M3_SS_TN +struct MMA_64x16x32_F16E5M2E4M3_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -40660,6 +41217,7 @@ struct SM90_64x16x32_F16E5M2E4M3_SS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + 
cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -40675,7 +41233,7 @@ struct SM90_64x16x32_F16E5M2E4M3_SS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x16x32_F16E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x16x32_F16E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -40687,7 +41245,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x16x32_F16E5M2E4M3_RS_TN +struct MMA_64x16x32_F16E5M2E4M3_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -40701,6 +41259,7 @@ struct SM90_64x16x32_F16E5M2E4M3_RS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -40716,7 +41275,7 @@ struct SM90_64x16x32_F16E5M2E4M3_RS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x16x32_F16E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x16x32_F16E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -40728,7 +41287,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x16x32_F32E5M2E4M3_SS_TN +struct MMA_64x16x32_F32E5M2E4M3_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -40743,6 +41302,7 @@ struct SM90_64x16x32_F32E5M2E4M3_SS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -40759,7 +41319,7 @@ struct SM90_64x16x32_F32E5M2E4M3_SS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x16x32_F32E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x16x32_F32E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -40771,7 +41331,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x16x32_F32E5M2E4M3_RS_TN +struct MMA_64x16x32_F32E5M2E4M3_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -40786,6 +41346,7 @@ struct SM90_64x16x32_F32E5M2E4M3_RS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -40802,7 +41363,7 @@ struct SM90_64x16x32_F32E5M2E4M3_RS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x16x32_F32E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x16x32_F32E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -40814,7 +41375,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x32x32_F16E5M2E4M3_SS_TN +struct MMA_64x32x32_F16E5M2E4M3_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -40829,6 +41390,7 @@ 
struct SM90_64x32x32_F16E5M2E4M3_SS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -40845,7 +41407,7 @@ struct SM90_64x32x32_F16E5M2E4M3_SS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x32x32_F16E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x32x32_F16E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -40857,7 +41419,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x32x32_F16E5M2E4M3_RS_TN +struct MMA_64x32x32_F16E5M2E4M3_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -40872,6 +41434,7 @@ struct SM90_64x32x32_F16E5M2E4M3_RS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -40888,7 +41451,7 @@ struct SM90_64x32x32_F16E5M2E4M3_RS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x32x32_F16E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x32x32_F16E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -40900,7 +41463,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x32x32_F32E5M2E4M3_SS_TN +struct MMA_64x32x32_F32E5M2E4M3_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -40917,6 +41480,7 @@ struct SM90_64x32x32_F32E5M2E4M3_SS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -40936,7 +41500,7 @@ struct SM90_64x32x32_F32E5M2E4M3_SS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x32x32_F32E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x32x32_F32E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -40948,7 +41512,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x32x32_F32E5M2E4M3_RS_TN +struct MMA_64x32x32_F32E5M2E4M3_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -40965,6 +41529,7 @@ struct SM90_64x32x32_F32E5M2E4M3_RS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -40984,7 +41549,7 @@ struct SM90_64x32x32_F32E5M2E4M3_RS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x32x32_F32E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x32x32_F32E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -40997,7 +41562,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct 
SM90_64x48x32_F16E5M2E4M3_SS_TN +struct MMA_64x48x32_F16E5M2E4M3_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -41013,6 +41578,7 @@ struct SM90_64x48x32_F16E5M2E4M3_SS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -41031,7 +41597,7 @@ struct SM90_64x48x32_F16E5M2E4M3_SS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x48x32_F16E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x48x32_F16E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -41045,7 +41611,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x48x32_F16E5M2E4M3_RS_TN +struct MMA_64x48x32_F16E5M2E4M3_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -41061,6 +41627,7 @@ struct SM90_64x48x32_F16E5M2E4M3_RS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -41079,7 +41646,7 @@ struct SM90_64x48x32_F16E5M2E4M3_RS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x48x32_F16E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x48x32_F16E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -41093,7 +41660,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x48x32_F32E5M2E4M3_SS_TN +struct MMA_64x48x32_F32E5M2E4M3_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -41112,6 +41679,7 @@ struct SM90_64x48x32_F32E5M2E4M3_SS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -41134,7 +41702,7 @@ struct SM90_64x48x32_F32E5M2E4M3_SS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x48x32_F32E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x48x32_F32E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -41148,7 +41716,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x48x32_F32E5M2E4M3_RS_TN +struct MMA_64x48x32_F32E5M2E4M3_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -41167,6 +41735,7 @@ struct SM90_64x48x32_F32E5M2E4M3_RS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -41189,7 +41758,7 @@ struct SM90_64x48x32_F32E5M2E4M3_RS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x48x32_F32E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x48x32_F32E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } 
}; @@ -41202,7 +41771,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x64x32_F16E5M2E4M3_SS_TN +struct MMA_64x64x32_F16E5M2E4M3_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -41219,6 +41788,7 @@ struct SM90_64x64x32_F16E5M2E4M3_SS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -41238,7 +41808,7 @@ struct SM90_64x64x32_F16E5M2E4M3_SS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x64x32_F16E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x64x32_F16E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -41250,7 +41820,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x64x32_F16E5M2E4M3_RS_TN +struct MMA_64x64x32_F16E5M2E4M3_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -41267,6 +41837,7 @@ struct SM90_64x64x32_F16E5M2E4M3_RS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -41286,7 +41857,7 @@ struct SM90_64x64x32_F16E5M2E4M3_RS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x64x32_F16E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x64x32_F16E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -41298,7 +41869,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x64x32_F32E5M2E4M3_SS_TN +struct MMA_64x64x32_F32E5M2E4M3_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -41319,6 +41890,7 @@ struct SM90_64x64x32_F32E5M2E4M3_SS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -41344,7 +41916,7 @@ struct SM90_64x64x32_F32E5M2E4M3_SS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x64x32_F32E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x64x32_F32E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -41356,7 +41928,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x64x32_F32E5M2E4M3_RS_TN +struct MMA_64x64x32_F32E5M2E4M3_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -41377,6 +41949,7 @@ struct SM90_64x64x32_F32E5M2E4M3_RS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -41402,7 +41975,7 @@ struct SM90_64x64x32_F32E5M2E4M3_RS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x64x32_F32E5M2E4M3_RS_TN without 
CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x64x32_F32E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -41415,7 +41988,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x80x32_F16E5M2E4M3_SS_TN +struct MMA_64x80x32_F16E5M2E4M3_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -41433,6 +42006,7 @@ struct SM90_64x80x32_F16E5M2E4M3_SS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -41454,7 +42028,7 @@ struct SM90_64x80x32_F16E5M2E4M3_SS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x80x32_F16E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x80x32_F16E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -41468,7 +42042,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x80x32_F16E5M2E4M3_RS_TN +struct MMA_64x80x32_F16E5M2E4M3_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -41486,6 +42060,7 @@ struct SM90_64x80x32_F16E5M2E4M3_RS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -41507,7 +42082,7 @@ struct SM90_64x80x32_F16E5M2E4M3_RS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x80x32_F16E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x80x32_F16E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -41521,7 +42096,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x80x32_F32E5M2E4M3_SS_TN +struct MMA_64x80x32_F32E5M2E4M3_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -41544,6 +42119,7 @@ struct SM90_64x80x32_F32E5M2E4M3_SS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -41572,7 +42148,7 @@ struct SM90_64x80x32_F32E5M2E4M3_SS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x80x32_F32E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x80x32_F32E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -41586,7 +42162,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x80x32_F32E5M2E4M3_RS_TN +struct MMA_64x80x32_F32E5M2E4M3_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -41609,6 +42185,7 @@ struct SM90_64x80x32_F32E5M2E4M3_RS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -41637,7 +42214,7 @@ struct SM90_64x80x32_F32E5M2E4M3_RS_TN "l"(desc_b), "r"(int32_t(scale_D)), 
"n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x80x32_F32E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x80x32_F32E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -41650,7 +42227,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x96x32_F16E5M2E4M3_SS_TN +struct MMA_64x96x32_F16E5M2E4M3_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -41669,6 +42246,7 @@ struct SM90_64x96x32_F16E5M2E4M3_SS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -41691,7 +42269,7 @@ struct SM90_64x96x32_F16E5M2E4M3_SS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x96x32_F16E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x96x32_F16E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -41703,7 +42281,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x96x32_F16E5M2E4M3_RS_TN +struct MMA_64x96x32_F16E5M2E4M3_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -41722,6 +42300,7 @@ struct SM90_64x96x32_F16E5M2E4M3_RS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -41744,7 +42323,7 @@ struct SM90_64x96x32_F16E5M2E4M3_RS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x96x32_F16E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x96x32_F16E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -41756,7 +42335,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x96x32_F32E5M2E4M3_SS_TN +struct MMA_64x96x32_F32E5M2E4M3_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -41781,6 +42360,7 @@ struct SM90_64x96x32_F32E5M2E4M3_SS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -41812,7 +42392,7 @@ struct SM90_64x96x32_F32E5M2E4M3_SS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x96x32_F32E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x96x32_F32E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -41824,7 +42404,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x96x32_F32E5M2E4M3_RS_TN +struct MMA_64x96x32_F32E5M2E4M3_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -41849,6 +42429,7 @@ struct SM90_64x96x32_F32E5M2E4M3_RS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm 
volatile( "{\n" ".reg .pred p;\n" @@ -41880,7 +42461,7 @@ struct SM90_64x96x32_F32E5M2E4M3_RS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x96x32_F32E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x96x32_F32E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -41893,7 +42474,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x112x32_F16E5M2E4M3_SS_TN +struct MMA_64x112x32_F16E5M2E4M3_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -41913,6 +42494,7 @@ struct SM90_64x112x32_F16E5M2E4M3_SS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -41937,7 +42519,7 @@ struct SM90_64x112x32_F16E5M2E4M3_SS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x112x32_F16E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x112x32_F16E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -41951,7 +42533,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x112x32_F16E5M2E4M3_RS_TN +struct MMA_64x112x32_F16E5M2E4M3_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -41971,6 +42553,7 @@ struct SM90_64x112x32_F16E5M2E4M3_RS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -41995,7 +42578,7 @@ struct SM90_64x112x32_F16E5M2E4M3_RS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x112x32_F16E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x112x32_F16E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -42009,7 +42592,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x112x32_F32E5M2E4M3_SS_TN +struct MMA_64x112x32_F32E5M2E4M3_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -42036,6 +42619,7 @@ struct SM90_64x112x32_F32E5M2E4M3_SS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -42070,7 +42654,7 @@ struct SM90_64x112x32_F32E5M2E4M3_SS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x112x32_F32E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x112x32_F32E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -42084,7 +42668,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x112x32_F32E5M2E4M3_RS_TN +struct MMA_64x112x32_F32E5M2E4M3_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -42111,6 +42695,7 @@ struct SM90_64x112x32_F32E5M2E4M3_RS_TN 
GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -42145,7 +42730,7 @@ struct SM90_64x112x32_F32E5M2E4M3_RS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x112x32_F32E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x112x32_F32E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -42158,7 +42743,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x128x32_F16E5M2E4M3_SS_TN +struct MMA_64x128x32_F16E5M2E4M3_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -42179,6 +42764,7 @@ struct SM90_64x128x32_F16E5M2E4M3_SS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -42204,7 +42790,7 @@ struct SM90_64x128x32_F16E5M2E4M3_SS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x128x32_F16E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x128x32_F16E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -42216,7 +42802,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x128x32_F16E5M2E4M3_RS_TN +struct MMA_64x128x32_F16E5M2E4M3_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -42237,6 +42823,7 @@ struct SM90_64x128x32_F16E5M2E4M3_RS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -42262,7 +42849,7 @@ struct SM90_64x128x32_F16E5M2E4M3_RS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x128x32_F16E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x128x32_F16E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -42274,7 +42861,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x128x32_F32E5M2E4M3_SS_TN +struct MMA_64x128x32_F32E5M2E4M3_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -42303,6 +42890,7 @@ struct SM90_64x128x32_F32E5M2E4M3_SS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -42340,7 +42928,7 @@ struct SM90_64x128x32_F32E5M2E4M3_SS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x128x32_F32E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x128x32_F32E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -42352,7 +42940,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x128x32_F32E5M2E4M3_RS_TN +struct 
MMA_64x128x32_F32E5M2E4M3_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -42381,6 +42969,7 @@ struct SM90_64x128x32_F32E5M2E4M3_RS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -42418,7 +43007,7 @@ struct SM90_64x128x32_F32E5M2E4M3_RS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x128x32_F32E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x128x32_F32E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -42431,7 +43020,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x144x32_F16E5M2E4M3_SS_TN +struct MMA_64x144x32_F16E5M2E4M3_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -42453,6 +43042,7 @@ struct SM90_64x144x32_F16E5M2E4M3_SS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -42480,7 +43070,7 @@ struct SM90_64x144x32_F16E5M2E4M3_SS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x144x32_F16E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x144x32_F16E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -42494,7 +43084,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x144x32_F16E5M2E4M3_RS_TN +struct MMA_64x144x32_F16E5M2E4M3_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -42516,6 +43106,7 @@ struct SM90_64x144x32_F16E5M2E4M3_RS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -42543,7 +43134,7 @@ struct SM90_64x144x32_F16E5M2E4M3_RS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x144x32_F16E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x144x32_F16E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -42557,7 +43148,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x144x32_F32E5M2E4M3_SS_TN +struct MMA_64x144x32_F32E5M2E4M3_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -42588,6 +43179,7 @@ struct SM90_64x144x32_F32E5M2E4M3_SS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -42628,7 +43220,7 @@ struct SM90_64x144x32_F32E5M2E4M3_SS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x144x32_F32E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x144x32_F32E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -42642,7 
+43234,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x144x32_F32E5M2E4M3_RS_TN +struct MMA_64x144x32_F32E5M2E4M3_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -42673,6 +43265,7 @@ struct SM90_64x144x32_F32E5M2E4M3_RS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -42713,7 +43306,7 @@ struct SM90_64x144x32_F32E5M2E4M3_RS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x144x32_F32E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x144x32_F32E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -42727,7 +43320,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x160x32_F16E5M2E4M3_SS_TN +struct MMA_64x160x32_F16E5M2E4M3_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -42750,6 +43343,7 @@ struct SM90_64x160x32_F16E5M2E4M3_SS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -42778,7 +43372,7 @@ struct SM90_64x160x32_F16E5M2E4M3_SS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x160x32_F16E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x160x32_F16E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -42792,7 +43386,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x160x32_F16E5M2E4M3_RS_TN +struct MMA_64x160x32_F16E5M2E4M3_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -42815,6 +43409,7 @@ struct SM90_64x160x32_F16E5M2E4M3_RS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -42843,7 +43438,7 @@ struct SM90_64x160x32_F16E5M2E4M3_RS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x160x32_F16E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x160x32_F16E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -42857,7 +43452,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x160x32_F32E5M2E4M3_SS_TN +struct MMA_64x160x32_F32E5M2E4M3_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -42890,6 +43485,7 @@ struct SM90_64x160x32_F32E5M2E4M3_SS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -42933,7 +43529,7 @@ struct SM90_64x160x32_F32E5M2E4M3_SS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x160x32_F32E5M2E4M3_SS_TN without 
CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x160x32_F32E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -42947,7 +43543,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x160x32_F32E5M2E4M3_RS_TN +struct MMA_64x160x32_F32E5M2E4M3_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -42980,6 +43576,7 @@ struct SM90_64x160x32_F32E5M2E4M3_RS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -43023,7 +43620,7 @@ struct SM90_64x160x32_F32E5M2E4M3_RS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x160x32_F32E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x160x32_F32E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -43037,7 +43634,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x176x32_F16E5M2E4M3_SS_TN +struct MMA_64x176x32_F16E5M2E4M3_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -43061,6 +43658,7 @@ struct SM90_64x176x32_F16E5M2E4M3_SS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -43091,7 +43689,7 @@ struct SM90_64x176x32_F16E5M2E4M3_SS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x176x32_F16E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x176x32_F16E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -43105,7 +43703,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x176x32_F16E5M2E4M3_RS_TN +struct MMA_64x176x32_F16E5M2E4M3_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -43129,6 +43727,7 @@ struct SM90_64x176x32_F16E5M2E4M3_RS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -43159,7 +43758,7 @@ struct SM90_64x176x32_F16E5M2E4M3_RS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x176x32_F16E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x176x32_F16E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -43173,7 +43772,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x176x32_F32E5M2E4M3_SS_TN +struct MMA_64x176x32_F32E5M2E4M3_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -43208,6 +43807,7 @@ struct SM90_64x176x32_F32E5M2E4M3_SS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -43254,7 +43854,7 @@ struct SM90_64x176x32_F32E5M2E4M3_SS_TN "l"(desc_b), 
"r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x176x32_F32E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x176x32_F32E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -43268,7 +43868,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x176x32_F32E5M2E4M3_RS_TN +struct MMA_64x176x32_F32E5M2E4M3_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -43303,6 +43903,7 @@ struct SM90_64x176x32_F32E5M2E4M3_RS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -43349,7 +43950,7 @@ struct SM90_64x176x32_F32E5M2E4M3_RS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x176x32_F32E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x176x32_F32E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -43362,7 +43963,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x192x32_F16E5M2E4M3_SS_TN +struct MMA_64x192x32_F16E5M2E4M3_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -43387,6 +43988,7 @@ struct SM90_64x192x32_F16E5M2E4M3_SS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -43418,7 +44020,7 @@ struct SM90_64x192x32_F16E5M2E4M3_SS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x192x32_F16E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x192x32_F16E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -43430,7 +44032,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x192x32_F16E5M2E4M3_RS_TN +struct MMA_64x192x32_F16E5M2E4M3_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -43455,6 +44057,7 @@ struct SM90_64x192x32_F16E5M2E4M3_RS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -43486,7 +44089,7 @@ struct SM90_64x192x32_F16E5M2E4M3_RS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x192x32_F16E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x192x32_F16E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -43498,7 +44101,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x192x32_F32E5M2E4M3_SS_TN +struct MMA_64x192x32_F32E5M2E4M3_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -43535,6 +44138,7 @@ struct SM90_64x192x32_F32E5M2E4M3_SS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + 
cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -43584,7 +44188,7 @@ struct SM90_64x192x32_F32E5M2E4M3_SS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x192x32_F32E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x192x32_F32E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -43596,7 +44200,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x192x32_F32E5M2E4M3_RS_TN +struct MMA_64x192x32_F32E5M2E4M3_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -43633,6 +44237,7 @@ struct SM90_64x192x32_F32E5M2E4M3_RS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -43682,7 +44287,7 @@ struct SM90_64x192x32_F32E5M2E4M3_RS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x192x32_F32E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x192x32_F32E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -43695,7 +44300,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x208x32_F16E5M2E4M3_SS_TN +struct MMA_64x208x32_F16E5M2E4M3_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -43721,6 +44326,7 @@ struct SM90_64x208x32_F16E5M2E4M3_SS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -43754,7 +44360,7 @@ struct SM90_64x208x32_F16E5M2E4M3_SS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x208x32_F16E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x208x32_F16E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -43768,7 +44374,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x208x32_F16E5M2E4M3_RS_TN +struct MMA_64x208x32_F16E5M2E4M3_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -43794,6 +44400,7 @@ struct SM90_64x208x32_F16E5M2E4M3_RS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -43827,7 +44434,7 @@ struct SM90_64x208x32_F16E5M2E4M3_RS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x208x32_F16E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x208x32_F16E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -43841,7 +44448,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x208x32_F32E5M2E4M3_SS_TN +struct MMA_64x208x32_F32E5M2E4M3_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; @@ 
-43880,6 +44487,7 @@ struct SM90_64x208x32_F32E5M2E4M3_SS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -43932,7 +44540,7 @@ struct SM90_64x208x32_F32E5M2E4M3_SS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x208x32_F32E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x208x32_F32E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -43946,7 +44554,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x208x32_F32E5M2E4M3_RS_TN +struct MMA_64x208x32_F32E5M2E4M3_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -43985,6 +44593,7 @@ struct SM90_64x208x32_F32E5M2E4M3_RS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -44037,7 +44646,7 @@ struct SM90_64x208x32_F32E5M2E4M3_RS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x208x32_F32E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x208x32_F32E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -44051,7 +44660,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x224x32_F16E5M2E4M3_SS_TN +struct MMA_64x224x32_F16E5M2E4M3_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -44078,6 +44687,7 @@ struct SM90_64x224x32_F16E5M2E4M3_SS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -44112,7 +44722,7 @@ struct SM90_64x224x32_F16E5M2E4M3_SS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x224x32_F16E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x224x32_F16E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -44126,7 +44736,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x224x32_F16E5M2E4M3_RS_TN +struct MMA_64x224x32_F16E5M2E4M3_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -44153,6 +44763,7 @@ struct SM90_64x224x32_F16E5M2E4M3_RS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -44187,7 +44798,7 @@ struct SM90_64x224x32_F16E5M2E4M3_RS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x224x32_F16E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x224x32_F16E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -44201,7 +44812,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = 
GMMA::ScaleIn::One > -struct SM90_64x224x32_F32E5M2E4M3_SS_TN +struct MMA_64x224x32_F32E5M2E4M3_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -44242,6 +44853,7 @@ struct SM90_64x224x32_F32E5M2E4M3_SS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -44297,7 +44909,7 @@ struct SM90_64x224x32_F32E5M2E4M3_SS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x224x32_F32E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x224x32_F32E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -44311,7 +44923,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x224x32_F32E5M2E4M3_RS_TN +struct MMA_64x224x32_F32E5M2E4M3_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -44352,6 +44964,7 @@ struct SM90_64x224x32_F32E5M2E4M3_RS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -44407,7 +45020,7 @@ struct SM90_64x224x32_F32E5M2E4M3_RS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x224x32_F32E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x224x32_F32E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -44421,7 +45034,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x240x32_F16E5M2E4M3_SS_TN +struct MMA_64x240x32_F16E5M2E4M3_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -44449,6 +45062,7 @@ struct SM90_64x240x32_F16E5M2E4M3_SS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -44485,7 +45099,7 @@ struct SM90_64x240x32_F16E5M2E4M3_SS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x240x32_F16E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x240x32_F16E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -44499,7 +45113,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x240x32_F16E5M2E4M3_RS_TN +struct MMA_64x240x32_F16E5M2E4M3_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -44527,6 +45141,7 @@ struct SM90_64x240x32_F16E5M2E4M3_RS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -44563,7 +45178,7 @@ struct SM90_64x240x32_F16E5M2E4M3_RS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x240x32_F16E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use 
MMA_64x240x32_F16E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -44577,7 +45192,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x240x32_F32E5M2E4M3_SS_TN +struct MMA_64x240x32_F32E5M2E4M3_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -44620,6 +45235,7 @@ struct SM90_64x240x32_F32E5M2E4M3_SS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -44678,7 +45294,7 @@ struct SM90_64x240x32_F32E5M2E4M3_SS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x240x32_F32E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x240x32_F32E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -44692,7 +45308,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x240x32_F32E5M2E4M3_RS_TN +struct MMA_64x240x32_F32E5M2E4M3_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -44735,6 +45351,7 @@ struct SM90_64x240x32_F32E5M2E4M3_RS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -44793,7 +45410,7 @@ struct SM90_64x240x32_F32E5M2E4M3_RS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x240x32_F32E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x240x32_F32E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -44806,7 +45423,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x256x32_F16E5M2E4M3_SS_TN +struct MMA_64x256x32_F16E5M2E4M3_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -44835,6 +45452,7 @@ struct SM90_64x256x32_F16E5M2E4M3_SS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -44872,7 +45490,7 @@ struct SM90_64x256x32_F16E5M2E4M3_SS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x256x32_F16E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x256x32_F16E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -44884,7 +45502,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x256x32_F16E5M2E4M3_RS_TN +struct MMA_64x256x32_F16E5M2E4M3_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -44913,6 +45531,7 @@ struct SM90_64x256x32_F16E5M2E4M3_RS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -44950,7 +45569,7 @@ struct SM90_64x256x32_F16E5M2E4M3_RS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - 
CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x256x32_F16E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x256x32_F16E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -44962,7 +45581,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x256x32_F32E5M2E4M3_SS_TN +struct MMA_64x256x32_F32E5M2E4M3_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -45007,6 +45626,7 @@ struct SM90_64x256x32_F32E5M2E4M3_SS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -45068,7 +45688,7 @@ struct SM90_64x256x32_F32E5M2E4M3_SS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x256x32_F32E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x256x32_F32E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -45080,7 +45700,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x256x32_F32E5M2E4M3_RS_TN +struct MMA_64x256x32_F32E5M2E4M3_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -45125,6 +45745,7 @@ struct SM90_64x256x32_F32E5M2E4M3_RS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -45186,7 +45807,7 @@ struct SM90_64x256x32_F32E5M2E4M3_RS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x256x32_F32E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x256x32_F32E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -45198,7 +45819,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x8x32_F16E5M2E5M2_SS_TN +struct MMA_64x8x32_F16E5M2E5M2_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -45212,6 +45833,7 @@ struct SM90_64x8x32_F16E5M2E5M2_SS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -45227,7 +45849,7 @@ struct SM90_64x8x32_F16E5M2E5M2_SS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x8x32_F16E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x8x32_F16E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -45239,7 +45861,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x8x32_F16E5M2E5M2_RS_TN +struct MMA_64x8x32_F16E5M2E5M2_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -45253,6 +45875,7 @@ struct SM90_64x8x32_F16E5M2E5M2_RS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -45268,7 
+45891,7 @@ struct SM90_64x8x32_F16E5M2E5M2_RS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x8x32_F16E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x8x32_F16E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -45280,7 +45903,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x8x32_F32E5M2E5M2_SS_TN +struct MMA_64x8x32_F32E5M2E5M2_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -45294,6 +45917,7 @@ struct SM90_64x8x32_F32E5M2E5M2_SS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -45309,7 +45933,7 @@ struct SM90_64x8x32_F32E5M2E5M2_SS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x8x32_F32E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x8x32_F32E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -45321,7 +45945,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x8x32_F32E5M2E5M2_RS_TN +struct MMA_64x8x32_F32E5M2E5M2_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -45335,6 +45959,7 @@ struct SM90_64x8x32_F32E5M2E5M2_RS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -45350,7 +45975,7 @@ struct SM90_64x8x32_F32E5M2E5M2_RS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x8x32_F32E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x8x32_F32E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -45362,7 +45987,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x16x32_F16E5M2E5M2_SS_TN +struct MMA_64x16x32_F16E5M2E5M2_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -45376,6 +46001,7 @@ struct SM90_64x16x32_F16E5M2E5M2_SS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -45391,7 +46017,7 @@ struct SM90_64x16x32_F16E5M2E5M2_SS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x16x32_F16E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x16x32_F16E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -45403,7 +46029,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x16x32_F16E5M2E5M2_RS_TN +struct MMA_64x16x32_F16E5M2E5M2_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -45417,6 +46043,7 @@ struct SM90_64x16x32_F16E5M2E5M2_RS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if 
defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -45432,7 +46059,7 @@ struct SM90_64x16x32_F16E5M2E5M2_RS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x16x32_F16E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x16x32_F16E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -45444,7 +46071,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x16x32_F32E5M2E5M2_SS_TN +struct MMA_64x16x32_F32E5M2E5M2_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -45459,6 +46086,7 @@ struct SM90_64x16x32_F32E5M2E5M2_SS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -45475,7 +46103,7 @@ struct SM90_64x16x32_F32E5M2E5M2_SS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x16x32_F32E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x16x32_F32E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -45487,7 +46115,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x16x32_F32E5M2E5M2_RS_TN +struct MMA_64x16x32_F32E5M2E5M2_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -45502,6 +46130,7 @@ struct SM90_64x16x32_F32E5M2E5M2_RS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -45518,7 +46147,7 @@ struct SM90_64x16x32_F32E5M2E5M2_RS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x16x32_F32E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x16x32_F32E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -45530,7 +46159,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x32x32_F16E5M2E5M2_SS_TN +struct MMA_64x32x32_F16E5M2E5M2_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -45545,6 +46174,7 @@ struct SM90_64x32x32_F16E5M2E5M2_SS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -45561,7 +46191,7 @@ struct SM90_64x32x32_F16E5M2E5M2_SS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x32x32_F16E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x32x32_F16E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -45573,7 +46203,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x32x32_F16E5M2E5M2_RS_TN +struct MMA_64x32x32_F16E5M2E5M2_RS_TN { using DRegisters = void; using ARegisters = 
uint32_t[4]; @@ -45588,6 +46218,7 @@ struct SM90_64x32x32_F16E5M2E5M2_RS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -45604,7 +46235,7 @@ struct SM90_64x32x32_F16E5M2E5M2_RS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x32x32_F16E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x32x32_F16E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -45616,7 +46247,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x32x32_F32E5M2E5M2_SS_TN +struct MMA_64x32x32_F32E5M2E5M2_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -45633,6 +46264,7 @@ struct SM90_64x32x32_F32E5M2E5M2_SS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -45652,7 +46284,7 @@ struct SM90_64x32x32_F32E5M2E5M2_SS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x32x32_F32E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x32x32_F32E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -45664,7 +46296,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x32x32_F32E5M2E5M2_RS_TN +struct MMA_64x32x32_F32E5M2E5M2_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -45681,6 +46313,7 @@ struct SM90_64x32x32_F32E5M2E5M2_RS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -45700,7 +46333,7 @@ struct SM90_64x32x32_F32E5M2E5M2_RS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x32x32_F32E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x32x32_F32E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -45713,7 +46346,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x48x32_F16E5M2E5M2_SS_TN +struct MMA_64x48x32_F16E5M2E5M2_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -45729,6 +46362,7 @@ struct SM90_64x48x32_F16E5M2E5M2_SS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -45747,7 +46381,7 @@ struct SM90_64x48x32_F16E5M2E5M2_SS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x48x32_F16E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x48x32_F16E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -45761,7 +46395,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = 
GMMA::ScaleIn::One > -struct SM90_64x48x32_F16E5M2E5M2_RS_TN +struct MMA_64x48x32_F16E5M2E5M2_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -45777,6 +46411,7 @@ struct SM90_64x48x32_F16E5M2E5M2_RS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -45795,7 +46430,7 @@ struct SM90_64x48x32_F16E5M2E5M2_RS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x48x32_F16E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x48x32_F16E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -45809,7 +46444,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x48x32_F32E5M2E5M2_SS_TN +struct MMA_64x48x32_F32E5M2E5M2_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -45828,6 +46463,7 @@ struct SM90_64x48x32_F32E5M2E5M2_SS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -45850,7 +46486,7 @@ struct SM90_64x48x32_F32E5M2E5M2_SS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x48x32_F32E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x48x32_F32E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -45864,7 +46500,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x48x32_F32E5M2E5M2_RS_TN +struct MMA_64x48x32_F32E5M2E5M2_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -45883,6 +46519,7 @@ struct SM90_64x48x32_F32E5M2E5M2_RS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -45905,7 +46542,7 @@ struct SM90_64x48x32_F32E5M2E5M2_RS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x48x32_F32E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x48x32_F32E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -45918,7 +46555,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x64x32_F16E5M2E5M2_SS_TN +struct MMA_64x64x32_F16E5M2E5M2_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -45935,6 +46572,7 @@ struct SM90_64x64x32_F16E5M2E5M2_SS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -45954,7 +46592,7 @@ struct SM90_64x64x32_F16E5M2E5M2_SS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x64x32_F16E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x64x32_F16E5M2E5M2_SS_TN without 
CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -45966,7 +46604,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x64x32_F16E5M2E5M2_RS_TN +struct MMA_64x64x32_F16E5M2E5M2_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -45983,6 +46621,7 @@ struct SM90_64x64x32_F16E5M2E5M2_RS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -46002,7 +46641,7 @@ struct SM90_64x64x32_F16E5M2E5M2_RS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x64x32_F16E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x64x32_F16E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -46014,7 +46653,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x64x32_F32E5M2E5M2_SS_TN +struct MMA_64x64x32_F32E5M2E5M2_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -46035,6 +46674,7 @@ struct SM90_64x64x32_F32E5M2E5M2_SS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -46060,7 +46700,7 @@ struct SM90_64x64x32_F32E5M2E5M2_SS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x64x32_F32E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x64x32_F32E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -46072,7 +46712,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x64x32_F32E5M2E5M2_RS_TN +struct MMA_64x64x32_F32E5M2E5M2_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -46093,6 +46733,7 @@ struct SM90_64x64x32_F32E5M2E5M2_RS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -46118,7 +46759,7 @@ struct SM90_64x64x32_F32E5M2E5M2_RS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x64x32_F32E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x64x32_F32E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -46131,7 +46772,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x80x32_F16E5M2E5M2_SS_TN +struct MMA_64x80x32_F16E5M2E5M2_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -46149,6 +46790,7 @@ struct SM90_64x80x32_F16E5M2E5M2_SS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -46170,7 +46812,7 @@ struct SM90_64x80x32_F16E5M2E5M2_SS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use 
SM90_64x80x32_F16E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x80x32_F16E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -46184,7 +46826,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x80x32_F16E5M2E5M2_RS_TN +struct MMA_64x80x32_F16E5M2E5M2_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -46202,6 +46844,7 @@ struct SM90_64x80x32_F16E5M2E5M2_RS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -46223,7 +46866,7 @@ struct SM90_64x80x32_F16E5M2E5M2_RS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x80x32_F16E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x80x32_F16E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -46237,7 +46880,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x80x32_F32E5M2E5M2_SS_TN +struct MMA_64x80x32_F32E5M2E5M2_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -46260,6 +46903,7 @@ struct SM90_64x80x32_F32E5M2E5M2_SS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -46288,7 +46932,7 @@ struct SM90_64x80x32_F32E5M2E5M2_SS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x80x32_F32E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x80x32_F32E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -46302,7 +46946,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x80x32_F32E5M2E5M2_RS_TN +struct MMA_64x80x32_F32E5M2E5M2_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -46325,6 +46969,7 @@ struct SM90_64x80x32_F32E5M2E5M2_RS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -46353,7 +46998,7 @@ struct SM90_64x80x32_F32E5M2E5M2_RS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x80x32_F32E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x80x32_F32E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -46366,7 +47011,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x96x32_F16E5M2E5M2_SS_TN +struct MMA_64x96x32_F16E5M2E5M2_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -46385,6 +47030,7 @@ struct SM90_64x96x32_F16E5M2E5M2_SS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -46407,7 +47053,7 @@ struct SM90_64x96x32_F16E5M2E5M2_SS_TN 
"l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x96x32_F16E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x96x32_F16E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -46419,7 +47065,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x96x32_F16E5M2E5M2_RS_TN +struct MMA_64x96x32_F16E5M2E5M2_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -46438,6 +47084,7 @@ struct SM90_64x96x32_F16E5M2E5M2_RS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -46460,7 +47107,7 @@ struct SM90_64x96x32_F16E5M2E5M2_RS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x96x32_F16E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x96x32_F16E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -46472,7 +47119,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x96x32_F32E5M2E5M2_SS_TN +struct MMA_64x96x32_F32E5M2E5M2_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -46497,6 +47144,7 @@ struct SM90_64x96x32_F32E5M2E5M2_SS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -46528,7 +47176,7 @@ struct SM90_64x96x32_F32E5M2E5M2_SS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x96x32_F32E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x96x32_F32E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -46540,7 +47188,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x96x32_F32E5M2E5M2_RS_TN +struct MMA_64x96x32_F32E5M2E5M2_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -46565,6 +47213,7 @@ struct SM90_64x96x32_F32E5M2E5M2_RS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -46596,7 +47245,7 @@ struct SM90_64x96x32_F32E5M2E5M2_RS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x96x32_F32E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x96x32_F32E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -46609,7 +47258,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x112x32_F16E5M2E5M2_SS_TN +struct MMA_64x112x32_F16E5M2E5M2_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -46629,6 +47278,7 @@ struct SM90_64x112x32_F16E5M2E5M2_SS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + 
cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -46653,7 +47303,7 @@ struct SM90_64x112x32_F16E5M2E5M2_SS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x112x32_F16E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x112x32_F16E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -46667,7 +47317,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x112x32_F16E5M2E5M2_RS_TN +struct MMA_64x112x32_F16E5M2E5M2_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -46687,6 +47337,7 @@ struct SM90_64x112x32_F16E5M2E5M2_RS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -46711,7 +47362,7 @@ struct SM90_64x112x32_F16E5M2E5M2_RS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x112x32_F16E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x112x32_F16E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -46725,7 +47376,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x112x32_F32E5M2E5M2_SS_TN +struct MMA_64x112x32_F32E5M2E5M2_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -46752,6 +47403,7 @@ struct SM90_64x112x32_F32E5M2E5M2_SS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -46786,7 +47438,7 @@ struct SM90_64x112x32_F32E5M2E5M2_SS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x112x32_F32E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x112x32_F32E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -46800,7 +47452,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x112x32_F32E5M2E5M2_RS_TN +struct MMA_64x112x32_F32E5M2E5M2_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -46827,6 +47479,7 @@ struct SM90_64x112x32_F32E5M2E5M2_RS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -46861,7 +47514,7 @@ struct SM90_64x112x32_F32E5M2E5M2_RS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x112x32_F32E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x112x32_F32E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -46874,7 +47527,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x128x32_F16E5M2E5M2_SS_TN +struct MMA_64x128x32_F16E5M2E5M2_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; @@ 
-46895,6 +47548,7 @@ struct SM90_64x128x32_F16E5M2E5M2_SS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -46920,7 +47574,7 @@ struct SM90_64x128x32_F16E5M2E5M2_SS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x128x32_F16E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x128x32_F16E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -46932,7 +47586,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x128x32_F16E5M2E5M2_RS_TN +struct MMA_64x128x32_F16E5M2E5M2_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -46953,6 +47607,7 @@ struct SM90_64x128x32_F16E5M2E5M2_RS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -46978,7 +47633,7 @@ struct SM90_64x128x32_F16E5M2E5M2_RS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x128x32_F16E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x128x32_F16E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -46990,7 +47645,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x128x32_F32E5M2E5M2_SS_TN +struct MMA_64x128x32_F32E5M2E5M2_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -47019,6 +47674,7 @@ struct SM90_64x128x32_F32E5M2E5M2_SS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -47056,7 +47712,7 @@ struct SM90_64x128x32_F32E5M2E5M2_SS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x128x32_F32E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x128x32_F32E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -47068,7 +47724,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x128x32_F32E5M2E5M2_RS_TN +struct MMA_64x128x32_F32E5M2E5M2_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -47097,6 +47753,7 @@ struct SM90_64x128x32_F32E5M2E5M2_RS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -47134,7 +47791,7 @@ struct SM90_64x128x32_F32E5M2E5M2_RS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x128x32_F32E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x128x32_F32E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -47147,7 +47804,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = 
GMMA::ScaleIn::One > -struct SM90_64x144x32_F16E5M2E5M2_SS_TN +struct MMA_64x144x32_F16E5M2E5M2_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -47169,6 +47826,7 @@ struct SM90_64x144x32_F16E5M2E5M2_SS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -47196,7 +47854,7 @@ struct SM90_64x144x32_F16E5M2E5M2_SS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x144x32_F16E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x144x32_F16E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -47210,7 +47868,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x144x32_F16E5M2E5M2_RS_TN +struct MMA_64x144x32_F16E5M2E5M2_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -47232,6 +47890,7 @@ struct SM90_64x144x32_F16E5M2E5M2_RS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -47259,7 +47918,7 @@ struct SM90_64x144x32_F16E5M2E5M2_RS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x144x32_F16E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x144x32_F16E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -47273,7 +47932,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x144x32_F32E5M2E5M2_SS_TN +struct MMA_64x144x32_F32E5M2E5M2_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -47304,6 +47963,7 @@ struct SM90_64x144x32_F32E5M2E5M2_SS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -47344,7 +48004,7 @@ struct SM90_64x144x32_F32E5M2E5M2_SS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x144x32_F32E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x144x32_F32E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -47358,7 +48018,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x144x32_F32E5M2E5M2_RS_TN +struct MMA_64x144x32_F32E5M2E5M2_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -47389,6 +48049,7 @@ struct SM90_64x144x32_F32E5M2E5M2_RS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -47429,7 +48090,7 @@ struct SM90_64x144x32_F32E5M2E5M2_RS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x144x32_F32E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use 
MMA_64x144x32_F32E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -47443,7 +48104,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x160x32_F16E5M2E5M2_SS_TN +struct MMA_64x160x32_F16E5M2E5M2_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -47466,6 +48127,7 @@ struct SM90_64x160x32_F16E5M2E5M2_SS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -47494,7 +48156,7 @@ struct SM90_64x160x32_F16E5M2E5M2_SS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x160x32_F16E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x160x32_F16E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -47508,7 +48170,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x160x32_F16E5M2E5M2_RS_TN +struct MMA_64x160x32_F16E5M2E5M2_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -47531,6 +48193,7 @@ struct SM90_64x160x32_F16E5M2E5M2_RS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -47559,7 +48222,7 @@ struct SM90_64x160x32_F16E5M2E5M2_RS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x160x32_F16E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x160x32_F16E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -47573,7 +48236,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x160x32_F32E5M2E5M2_SS_TN +struct MMA_64x160x32_F32E5M2E5M2_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -47606,6 +48269,7 @@ struct SM90_64x160x32_F32E5M2E5M2_SS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -47649,7 +48313,7 @@ struct SM90_64x160x32_F32E5M2E5M2_SS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x160x32_F32E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x160x32_F32E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -47663,7 +48327,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x160x32_F32E5M2E5M2_RS_TN +struct MMA_64x160x32_F32E5M2E5M2_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -47696,6 +48360,7 @@ struct SM90_64x160x32_F32E5M2E5M2_RS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -47739,7 +48404,7 @@ struct SM90_64x160x32_F32E5M2E5M2_RS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - 
CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x160x32_F32E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x160x32_F32E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -47753,7 +48418,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x176x32_F16E5M2E5M2_SS_TN +struct MMA_64x176x32_F16E5M2E5M2_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -47777,6 +48442,7 @@ struct SM90_64x176x32_F16E5M2E5M2_SS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -47807,7 +48473,7 @@ struct SM90_64x176x32_F16E5M2E5M2_SS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x176x32_F16E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x176x32_F16E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -47821,7 +48487,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x176x32_F16E5M2E5M2_RS_TN +struct MMA_64x176x32_F16E5M2E5M2_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -47845,6 +48511,7 @@ struct SM90_64x176x32_F16E5M2E5M2_RS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -47875,7 +48542,7 @@ struct SM90_64x176x32_F16E5M2E5M2_RS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x176x32_F16E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x176x32_F16E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -47889,7 +48556,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x176x32_F32E5M2E5M2_SS_TN +struct MMA_64x176x32_F32E5M2E5M2_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -47924,6 +48591,7 @@ struct SM90_64x176x32_F32E5M2E5M2_SS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -47970,7 +48638,7 @@ struct SM90_64x176x32_F32E5M2E5M2_SS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x176x32_F32E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x176x32_F32E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -47984,7 +48652,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x176x32_F32E5M2E5M2_RS_TN +struct MMA_64x176x32_F32E5M2E5M2_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -48019,6 +48687,7 @@ struct SM90_64x176x32_F32E5M2E5M2_RS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred 
p;\n" @@ -48065,7 +48734,7 @@ struct SM90_64x176x32_F32E5M2E5M2_RS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x176x32_F32E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x176x32_F32E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -48078,7 +48747,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x192x32_F16E5M2E5M2_SS_TN +struct MMA_64x192x32_F16E5M2E5M2_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -48103,6 +48772,7 @@ struct SM90_64x192x32_F16E5M2E5M2_SS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -48134,7 +48804,7 @@ struct SM90_64x192x32_F16E5M2E5M2_SS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x192x32_F16E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x192x32_F16E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -48146,7 +48816,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x192x32_F16E5M2E5M2_RS_TN +struct MMA_64x192x32_F16E5M2E5M2_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -48171,6 +48841,7 @@ struct SM90_64x192x32_F16E5M2E5M2_RS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -48202,7 +48873,7 @@ struct SM90_64x192x32_F16E5M2E5M2_RS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x192x32_F16E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x192x32_F16E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -48214,7 +48885,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x192x32_F32E5M2E5M2_SS_TN +struct MMA_64x192x32_F32E5M2E5M2_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -48251,6 +48922,7 @@ struct SM90_64x192x32_F32E5M2E5M2_SS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -48300,7 +48972,7 @@ struct SM90_64x192x32_F32E5M2E5M2_SS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x192x32_F32E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x192x32_F32E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -48312,7 +48984,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x192x32_F32E5M2E5M2_RS_TN +struct MMA_64x192x32_F32E5M2E5M2_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -48349,6 +49021,7 @@ struct SM90_64x192x32_F32E5M2E5M2_RS_TN GMMA::ScaleOut const scale_D = 
GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -48398,7 +49071,7 @@ struct SM90_64x192x32_F32E5M2E5M2_RS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x192x32_F32E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x192x32_F32E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -48411,7 +49084,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x208x32_F16E5M2E5M2_SS_TN +struct MMA_64x208x32_F16E5M2E5M2_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -48437,6 +49110,7 @@ struct SM90_64x208x32_F16E5M2E5M2_SS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -48470,7 +49144,7 @@ struct SM90_64x208x32_F16E5M2E5M2_SS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x208x32_F16E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x208x32_F16E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -48484,7 +49158,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x208x32_F16E5M2E5M2_RS_TN +struct MMA_64x208x32_F16E5M2E5M2_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -48510,6 +49184,7 @@ struct SM90_64x208x32_F16E5M2E5M2_RS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -48543,7 +49218,7 @@ struct SM90_64x208x32_F16E5M2E5M2_RS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x208x32_F16E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x208x32_F16E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -48557,7 +49232,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x208x32_F32E5M2E5M2_SS_TN +struct MMA_64x208x32_F32E5M2E5M2_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -48596,6 +49271,7 @@ struct SM90_64x208x32_F32E5M2E5M2_SS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -48648,7 +49324,7 @@ struct SM90_64x208x32_F32E5M2E5M2_SS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x208x32_F32E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x208x32_F32E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -48662,7 +49338,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x208x32_F32E5M2E5M2_RS_TN +struct 
MMA_64x208x32_F32E5M2E5M2_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -48701,6 +49377,7 @@ struct SM90_64x208x32_F32E5M2E5M2_RS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -48753,7 +49430,7 @@ struct SM90_64x208x32_F32E5M2E5M2_RS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x208x32_F32E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x208x32_F32E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -48767,7 +49444,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x224x32_F16E5M2E5M2_SS_TN +struct MMA_64x224x32_F16E5M2E5M2_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -48794,6 +49471,7 @@ struct SM90_64x224x32_F16E5M2E5M2_SS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -48828,7 +49506,7 @@ struct SM90_64x224x32_F16E5M2E5M2_SS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x224x32_F16E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x224x32_F16E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -48842,7 +49520,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x224x32_F16E5M2E5M2_RS_TN +struct MMA_64x224x32_F16E5M2E5M2_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -48869,6 +49547,7 @@ struct SM90_64x224x32_F16E5M2E5M2_RS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -48903,7 +49582,7 @@ struct SM90_64x224x32_F16E5M2E5M2_RS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x224x32_F16E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x224x32_F16E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -48917,7 +49596,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x224x32_F32E5M2E5M2_SS_TN +struct MMA_64x224x32_F32E5M2E5M2_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -48958,6 +49637,7 @@ struct SM90_64x224x32_F32E5M2E5M2_SS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -49013,7 +49693,7 @@ struct SM90_64x224x32_F32E5M2E5M2_SS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x224x32_F32E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x224x32_F32E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -49027,7 
+49707,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x224x32_F32E5M2E5M2_RS_TN +struct MMA_64x224x32_F32E5M2E5M2_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -49068,6 +49748,7 @@ struct SM90_64x224x32_F32E5M2E5M2_RS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -49123,7 +49804,7 @@ struct SM90_64x224x32_F32E5M2E5M2_RS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x224x32_F32E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x224x32_F32E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -49137,7 +49818,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x240x32_F16E5M2E5M2_SS_TN +struct MMA_64x240x32_F16E5M2E5M2_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -49165,6 +49846,7 @@ struct SM90_64x240x32_F16E5M2E5M2_SS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -49201,7 +49883,7 @@ struct SM90_64x240x32_F16E5M2E5M2_SS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x240x32_F16E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x240x32_F16E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -49215,7 +49897,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x240x32_F16E5M2E5M2_RS_TN +struct MMA_64x240x32_F16E5M2E5M2_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -49243,6 +49925,7 @@ struct SM90_64x240x32_F16E5M2E5M2_RS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -49279,7 +49962,7 @@ struct SM90_64x240x32_F16E5M2E5M2_RS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x240x32_F16E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x240x32_F16E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -49293,7 +49976,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x240x32_F32E5M2E5M2_SS_TN +struct MMA_64x240x32_F32E5M2E5M2_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -49336,6 +50019,7 @@ struct SM90_64x240x32_F32E5M2E5M2_SS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -49394,7 +50078,7 @@ struct SM90_64x240x32_F32E5M2E5M2_SS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x240x32_F32E5M2E5M2_SS_TN without 
CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x240x32_F32E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -49408,7 +50092,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x240x32_F32E5M2E5M2_RS_TN +struct MMA_64x240x32_F32E5M2E5M2_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -49451,6 +50135,7 @@ struct SM90_64x240x32_F32E5M2E5M2_RS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -49509,7 +50194,7 @@ struct SM90_64x240x32_F32E5M2E5M2_RS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x240x32_F32E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x240x32_F32E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -49522,7 +50207,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x256x32_F16E5M2E5M2_SS_TN +struct MMA_64x256x32_F16E5M2E5M2_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -49551,6 +50236,7 @@ struct SM90_64x256x32_F16E5M2E5M2_SS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -49588,7 +50274,7 @@ struct SM90_64x256x32_F16E5M2E5M2_SS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x256x32_F16E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x256x32_F16E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -49600,7 +50286,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x256x32_F16E5M2E5M2_RS_TN +struct MMA_64x256x32_F16E5M2E5M2_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -49629,6 +50315,7 @@ struct SM90_64x256x32_F16E5M2E5M2_RS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -49666,7 +50353,7 @@ struct SM90_64x256x32_F16E5M2E5M2_RS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x256x32_F16E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x256x32_F16E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -49678,7 +50365,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x256x32_F32E5M2E5M2_SS_TN +struct MMA_64x256x32_F32E5M2E5M2_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -49723,6 +50410,7 @@ struct SM90_64x256x32_F32E5M2E5M2_SS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -49784,7 +50472,7 @@ struct SM90_64x256x32_F32E5M2E5M2_SS_TN "l"(desc_b), 
"r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x256x32_F32E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x256x32_F32E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -49796,7 +50484,7 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct SM90_64x256x32_F32E5M2E5M2_RS_TN +struct MMA_64x256x32_F32E5M2E5M2_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -49841,6 +50529,7 @@ struct SM90_64x256x32_F32E5M2E5M2_RS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -49902,11 +50591,14 @@ struct SM90_64x256x32_F32E5M2E5M2_RS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x256x32_F32E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x256x32_F32E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; //////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace SM90::GMMA + } // namespace cute diff --git a/include/cute/arch/mma_sm90_gmma_sparse.hpp b/include/cute/arch/mma_sm90_gmma_sparse.hpp new file mode 100644 index 0000000000..d05e762e1f --- /dev/null +++ b/include/cute/arch/mma_sm90_gmma_sparse.hpp @@ -0,0 +1,53789 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ * + **************************************************************************************************/ + +#pragma once + +#include // CUTE_HOST_DEVICE +#include // GMMA::Major, etc. + +namespace cute { + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// GMMA PTX definitions: C = (scaleA * A) * (scaleB * B) + (scaleD * C) +//////////////////////////////////////////////////////////////////////////////////////////////////// + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +namespace SM90::GMMA::SPARSE { + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x8x32 F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x8x32_F16F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[2]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %6, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n8k32.f16.f16.f16 " + "{%0, %1}," + " %2," + " %3," + " %4, %5," + " p, %7, %8, %9, %10;\n" + "}\n" + : "+r"(d0), "+r"(d1) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x8x32_F16F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x8x32 F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x8x32_F16F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[2]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %9, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n8k32.f16.f16.f16 " + "{%0, %1}," + "{%2, %3, %4, %5}," + " %6," + " %7, %8," + " p, %10, %11, %12;\n" + "}\n" + : "+r"(d0), "+r"(d1) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x8x32_F16F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + 
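[Editorial aside, not part of the patch.] The two 64x8x32 sparse wrappers above expose the `wgmma.mma_async.sp` instruction through a plain `fma` call. Below is a minimal sketch of how the SS variant could be driven directly from device code compiled for sm_90a; the descriptor and metadata values are placeholders, and real kernels reach these structs through CuTe MMA atoms, which also issue the required wgmma fence/commit_group/wait_group sequence that this fragment omits.

#include <cute/arch/mma_sm90_gmma_sparse.hpp>

// Illustrative only: assumes desc_a/desc_b are valid GMMA shared-memory
// descriptors and `metadata` holds the 2:4 sparsity selectors for A.
__device__ void sparse_gmma_64x8x32_f16_sketch(uint64_t desc_a,
                                               uint64_t desc_b,
                                               uint32_t metadata,
                                               uint32_t& d0, uint32_t& d1)
{
  using Op = cute::SM90::GMMA::SPARSE::GMMA_64x8x32_F16F16F16_SS<
      cute::GMMA::Major::K, cute::GMMA::Major::K>;
  // ScaleOut::One accumulates into the caller's packed f16x2 registers {d0, d1}.
  Op::fma(desc_a, desc_b, d0, d1, metadata, cute::GMMA::ScaleOut::One);
}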
+//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x16x32 F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x16x32_F16F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[4]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %8, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n16k32.f16.f16.f16 " + "{%0, %1, %2, %3}," + " %4," + " %5," + " %6, %7," + " p, %9, %10, %11, %12;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x16x32_F16F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x16x32 F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x16x32_F16F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[4]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %11, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n16k32.f16.f16.f16 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + " %8," + " %9, %10," + " p, %12, %13, %14;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x16x32_F16F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x32x32 F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x32x32_F16F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[8]; + + CUTE_HOST_DEVICE static void 
+ fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t & d4, uint32_t & d5, uint32_t & d6, uint32_t & d7, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %12, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n32k32.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7}," + " %8," + " %9," + " %10, %11," + " p, %13, %14, %15, %16;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3), + "+r"(d4), "+r"(d5), "+r"(d6), "+r"(d7) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x32x32_F16F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x32x32 F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x32x32_F16F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[8]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t & d4, uint32_t & d5, uint32_t & d6, uint32_t & d7, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %15, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n32k32.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7}," + "{%8, %9, %10, %11}," + " %12," + " %13, %14," + " p, %16, %17, %18;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3), + "+r"(d4), "+r"(d5), "+r"(d6), "+r"(d7) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x32x32_F16F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x48x32 F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x48x32_F16F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[12]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t 
const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %16, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n48k32.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11}," + " %12," + " %13," + " %14, %15," + " p, %17, %18, %19, %20;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x48x32_F16F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x48x32 F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x48x32_F16F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[12]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %19, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n48k32.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11}," + "{%12, %13, %14, %15}," + " %16," + " %17, %18," + " p, %20, %21, %22;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x48x32_F16F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x64x32 F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x64x32_F16F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[16]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & 
d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %20, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n64k32.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15}," + " %16," + " %17," + " %18, %19," + " p, %21, %22, %23, %24;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x64x32_F16F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x64x32 F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x64x32_F16F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[16]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %23, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n64k32.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15}," + "{%16, %17, %18, %19}," + " %20," + " %21, %22," + " p, %24, %25, %26;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x64x32_F16F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x80x32 F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x80x32_F16F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = 
uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[20]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %24, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n80k32.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19}," + " %20," + " %21," + " %22, %23," + " p, %25, %26, %27, %28;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x80x32_F16F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x80x32 F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x80x32_F16F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[20]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %27, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n80k32.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19}," + "{%20, %21, %22, %23}," + " %24," + " %25, %26," + " p, %28, %29, %30;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + 
CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x80x32_F16F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x96x32 F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x96x32_F16F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[24]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %28, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n96k32.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23}," + " %24," + " %25," + " %26, %27," + " p, %29, %30, %31, %32;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x96x32_F16F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x96x32 F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x96x32_F16F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[24]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + 
asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %31, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n96k32.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23}," + "{%24, %25, %26, %27}," + " %28," + " %29, %30," + " p, %32, %33, %34;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x96x32_F16F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x112x32 F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x112x32_F16F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[28]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %32, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n112k32.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27}," + " %28," + " %29," + " %30, %31," + " p, %33, %34, %35, %36;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x112x32_F16F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x112x32 F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + 
GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x112x32_F16F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[28]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %35, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n112k32.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27}," + "{%28, %29, %30, %31}," + " %32," + " %33, %34," + " p, %36, %37, %38;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x112x32_F16F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x128x32 F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x128x32_F16F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[32]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %36, 0;\n" + 
"wgmma.mma_async.sp.sync.aligned.m64n128k32.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31}," + " %32," + " %33," + " %34, %35," + " p, %37, %38, %39, %40;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x128x32_F16F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x128x32 F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x128x32_F16F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[32]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %39, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n128k32.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31}," + "{%32, %33, %34, %35}," + " %36," + " %37, %38," + " p, %40, %41, %42;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x128x32_F16F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + 
+//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x144x32 F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x144x32_F16F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[36]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %40, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n144k32.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35}," + " %36," + " %37," + " %38, %39," + " p, %41, %42, %43, %44;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x144x32_F16F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x144x32 F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x144x32_F16F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[36]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, 
uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %43, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n144k32.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35}," + "{%36, %37, %38, %39}," + " %40," + " %41, %42," + " p, %44, %45, %46;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x144x32_F16F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x160x32 F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x160x32_F16F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[40]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %44, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n160k32.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, 
%39}," + " %40," + " %41," + " %42, %43," + " p, %45, %46, %47, %48;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x160x32_F16F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x160x32 F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x160x32_F16F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[40]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %47, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n160k32.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39}," + "{%40, %41, %42, %43}," + " %44," + " %45, %46," + " p, %48, %49, %50;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), 
"n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x160x32_F16F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x176x32 F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x176x32_F16F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[44]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %48, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n176k32.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43}," + " %44," + " %45," + " %46, %47," + " p, %49, %50, %51, %52;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x176x32_F16F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x176x32 F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x176x32_F16F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = 
uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[44]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %51, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n176k32.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43}," + "{%44, %45, %46, %47}," + " %48," + " %49, %50," + " p, %52, %53, %54;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x176x32_F16F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x192x32 F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x192x32_F16F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[48]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & 
d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %52, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n192k32.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47}," + " %48," + " %49," + " %50, %51," + " p, %53, %54, %55, %56;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x192x32_F16F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x192x32 F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x192x32_F16F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[48]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + 
asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %55, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n192k32.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47}," + "{%48, %49, %50, %51}," + " %52," + " %53, %54," + " p, %56, %57, %58;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x192x32_F16F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x208x32 F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x208x32_F16F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[52]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %56, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n208k32.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51}," + " %52," + " %53," + " %54, %55," + " p, %57, %58, %59, %60;\n" + "}\n" + : "+r"(d00), 
"+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x208x32_F16F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x208x32 F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x208x32_F16F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[52]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %59, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n208k32.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51}," + "{%52, %53, %54, %55}," + " %56," + " %57, %58," + " p, %60, %61, %62;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), 
"+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x208x32_F16F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x224x32 F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x224x32_F16F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[56]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %60, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n224k32.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55}," + " %56," + " %57," + " %58, %59," + " p, %61, %62, %63, %64;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55) + : 
"l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x224x32_F16F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x224x32 F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x224x32_F16F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[56]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %63, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n224k32.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55}," + "{%56, %57, %58, %59}," + " %60," + " %61, %62," + " p, %64, %65, %66;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), 
"n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x224x32_F16F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x240x32 F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x240x32_F16F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[60]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %64, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n240k32.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59}," + " %60," + " %61," + " %62, %63," + " p, %65, %66, %67, %68;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x240x32_F16F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); 
+#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x240x32 F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x240x32_F16F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[60]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %67, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n240k32.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59}," + "{%60, %61, %62, %63}," + " %64," + " %65, %66," + " p, %68, %69, %70;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x240x32_F16F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); 
+#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x256x32 F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x256x32_F16F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[64]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %68, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n256k32.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63}," + " %64," + " %65," + " %66, %67," + " p, %69, %70, %71, %72;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x256x32_F16F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + 
+//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x256x32 F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x256x32_F16F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[64]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %71, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n256k32.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63}," + "{%64, %65, %66, %67}," + " %68," + " %69, %70," + " p, %72, %73, %74;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use 
SM90::GMMA::SPARSE::GMMA_64x256x32_F16F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x8x32 F32+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x8x32_F32F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[4]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d0, float & d1, float & d2, float & d3, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %8, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n8k32.f32.f16.f16 " + "{%0, %1, %2, %3}," + " %4," + " %5," + " %6, %7," + " p, %9, %10, %11, %12;\n" + "}\n" + : "+f"(d0), "+f"(d1), "+f"(d2), "+f"(d3) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x8x32_F32F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x8x32 F32+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x8x32_F32F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[4]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint64_t const& desc_b, + float & d0, float & d1, float & d2, float & d3, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %11, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n8k32.f32.f16.f16 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + " %8," + " %9, %10," + " p, %12, %13, %14;\n" + "}\n" + : "+f"(d0), "+f"(d1), "+f"(d2), "+f"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x8x32_F32F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x16x32 F32+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x16x32_F32F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; 
+ using CRegisters = float[8]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d0, float & d1, float & d2, float & d3, + float & d4, float & d5, float & d6, float & d7, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %12, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n16k32.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7}," + " %8," + " %9," + " %10, %11," + " p, %13, %14, %15, %16;\n" + "}\n" + : "+f"(d0), "+f"(d1), "+f"(d2), "+f"(d3), + "+f"(d4), "+f"(d5), "+f"(d6), "+f"(d7) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x16x32_F32F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x16x32 F32+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x16x32_F32F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[8]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint64_t const& desc_b, + float & d0, float & d1, float & d2, float & d3, + float & d4, float & d5, float & d6, float & d7, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %15, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n16k32.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7}," + "{%8, %9, %10, %11}," + " %12," + " %13, %14," + " p, %16, %17, %18;\n" + "}\n" + : "+f"(d0), "+f"(d1), "+f"(d2), "+f"(d3), + "+f"(d4), "+f"(d5), "+f"(d6), "+f"(d7) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x16x32_F32F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x32x32 F32+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x32x32_F32F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[16]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + uint32_t const& e, + GMMA::ScaleOut 
const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %20, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n32k32.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15}," + " %16," + " %17," + " %18, %19," + " p, %21, %22, %23, %24;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x32x32_F32F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x32x32 F32+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x32x32_F32F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[16]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %23, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n32k32.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15}," + "{%16, %17, %18, %19}," + " %20," + " %21, %22," + " p, %24, %25, %26;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x32x32_F32F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x48x32 F32+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x48x32_F32F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[24]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, 
float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %28, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n48k32.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23}," + " %24," + " %25," + " %26, %27," + " p, %29, %30, %31, %32;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x48x32_F32F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x48x32 F32+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x48x32_F32F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[24]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %31, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n48k32.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23}," + "{%24, %25, %26, %27}," + " %28," + " %29, %30," + " p, %32, %33, %34;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x48x32_F32F16F16_RS without 
CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x64x32 F32+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x64x32_F32F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[32]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %36, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n64k32.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31}," + " %32," + " %33," + " %34, %35," + " p, %37, %38, %39, %40;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x64x32_F32F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x64x32 F32+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x64x32_F32F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[32]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + uint32_t const& e, 
+ GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %39, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n64k32.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31}," + "{%32, %33, %34, %35}," + " %36," + " %37, %38," + " p, %40, %41, %42;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x64x32_F32F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x80x32 F32+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x80x32_F32F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[40]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %44, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n80k32.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39}," + " %40," + " %41," + " %42, %43," + " p, %45, %46, %47, %48;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), 
"n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x80x32_F32F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x80x32 F32+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x80x32_F32F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[40]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %47, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n80k32.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39}," + "{%40, %41, %42, %43}," + " %44," + " %45, %46," + " p, %48, %49, %50;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x80x32_F32F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x96x32 F32+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x96x32_F32F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[48]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + 
uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %52, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n96k32.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47}," + " %48," + " %49," + " %50, %51," + " p, %53, %54, %55, %56;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x96x32_F32F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x96x32 F32+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x96x32_F32F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[48]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & 
d45, float & d46, float & d47, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %55, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n96k32.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47}," + "{%48, %49, %50, %51}," + " %52," + " %53, %54," + " p, %56, %57, %58;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x96x32_F32F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x112x32 F32+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x112x32_F32F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[56]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %60, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n112k32.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55}," + " %56," + " %57," + " 
%58, %59," + " p, %61, %62, %63, %64;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x112x32_F32F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x112x32 F32+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x112x32_F32F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[56]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %63, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n112k32.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55}," + "{%56, %57, %58, %59}," + " %60," + " %61, %62," + " p, %64, %65, %66;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), 
"+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x112x32_F32F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x128x32 F32+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x128x32_F32F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[64]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %68, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n128k32.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63}," + " %64," + " %65," + " %66, %67," + " p, %69, %70, %71, %72;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + 
"+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x128x32_F32F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x128x32 F32+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x128x32_F32F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[64]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %71, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n128k32.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63}," + "{%64, %65, %66, %67}," + " %68," + " %69, %70," + " p, %72, %73, %74;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + 
"l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x128x32_F32F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x144x32 F32+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x144x32_F32F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[72]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %76, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n144k32.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71}," + " %72," + " %73," + " %74, %75," + " p, %77, %78, %79, %80;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71) + : "l"(desc_a), + "l"(desc_b), + "r"(e), 
"n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x144x32_F32F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x144x32 F32+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x144x32_F32F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[72]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %79, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n144k32.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71}," + "{%72, %73, %74, %75}," + " %76," + " %77, %78," + " p, %80, %81, %82;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + 
"+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x144x32_F32F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x160x32 F32+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x160x32_F32F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[80]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %84, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n160k32.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79}," + " %80," + " %81," + " %82, %83," + " p, %85, %86, %87, %88;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), 
"+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x160x32_F32F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x160x32 F32+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x160x32_F32F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[80]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %87, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n160k32.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79}," + "{%80, %81, %82, %83}," + " %84," + " %85, %86," + " p, %88, %89, %90;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), 
"+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x160x32_F32F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x176x32 F32+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x176x32_F32F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[88]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + float & d80, float & d81, float & d82, float & d83, + float & d84, float & d85, float & d86, float & d87, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %92, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n176k32.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " 
%48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87}," + " %88," + " %89," + " %90, %91," + " p, %93, %94, %95, %96;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), + "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83), + "+f"(d84), "+f"(d85), "+f"(d86), "+f"(d87) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x176x32_F32F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x176x32 F32+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x176x32_F32F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[88]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float 
& d78, float & d79, + float & d80, float & d81, float & d82, float & d83, + float & d84, float & d85, float & d86, float & d87, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %95, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n176k32.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87}," + "{%88, %89, %90, %91}," + " %92," + " %93, %94," + " p, %96, %97, %98;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), + "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83), + "+f"(d84), "+f"(d85), "+f"(d86), "+f"(d87) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x176x32_F32F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x192x32 F32+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x192x32_F32F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[96]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & 
d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + float & d80, float & d81, float & d82, float & d83, + float & d84, float & d85, float & d86, float & d87, + float & d88, float & d89, float & d90, float & d91, + float & d92, float & d93, float & d94, float & d95, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %100, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n192k32.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95}," + " %96," + " %97," + " %98, %99," + " p, %101, %102, %103, %104;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), + "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83), + "+f"(d84), "+f"(d85), "+f"(d86), "+f"(d87), + "+f"(d88), "+f"(d89), "+f"(d90), "+f"(d91), + "+f"(d92), "+f"(d93), "+f"(d94), "+f"(d95) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x192x32_F32F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x192x32 F32+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x192x32_F32F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[96]; + + 
static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + float & d80, float & d81, float & d82, float & d83, + float & d84, float & d85, float & d86, float & d87, + float & d88, float & d89, float & d90, float & d91, + float & d92, float & d93, float & d94, float & d95, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %103, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n192k32.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95}," + "{%96, %97, %98, %99}," + " %100," + " %101, %102," + " p, %104, %105, %106;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), + "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83), + "+f"(d84), "+f"(d85), "+f"(d86), "+f"(d87), + "+f"(d88), "+f"(d89), "+f"(d90), "+f"(d91), + "+f"(d92), "+f"(d93), "+f"(d94), "+f"(d95) + : "r"(a00), 
"r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x192x32_F32F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x208x32 F32+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x208x32_F32F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[104]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %108, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n208k32.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103}," + " %104," + " %105," + " %106, %107," + " p, %109, %110, %111, %112;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + 
"+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x208x32_F32F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x208x32 F32+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x208x32_F32F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[104]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & 
d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %111, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n208k32.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103}," + "{%104, %105, %106, %107}," + " %108," + " %109, %110," + " p, %112, %113, %114;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x208x32_F32F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x224x32 F32+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x224x32_F32F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[112]; + + CUTE_HOST_DEVICE static 
void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %116, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n224k32.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111}," + " %112," + " %113," + " %114, %115," + " p, %117, %118, %119, %120;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), 
"+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x224x32_F32F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x224x32 F32+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x224x32_F32F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[112]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %119, 0;\n" + 
"wgmma.mma_async.sp.sync.aligned.m64n224k32.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111}," + "{%112, %113, %114, %115}," + " %116," + " %117, %118," + " p, %120, %121, %122;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x224x32_F32F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x240x32 F32+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x240x32_F32F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[120]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, 
float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + float & d112, float & d113, float & d114, float & d115, + float & d116, float & d117, float & d118, float & d119, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %124, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n240k32.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119}," + " %120," + " %121," + " %122, %123," + " p, %125, %126, %127, %128;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), 
"+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), + "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115), + "+f"(d116), "+f"(d117), "+f"(d118), "+f"(d119) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x240x32_F32F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x240x32 F32+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x240x32_F32F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[120]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + float & d112, float & d113, float & d114, float & d115, + float & d116, float & d117, float & d118, float & d119, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %127, 0;\n" + 
"wgmma.mma_async.sp.sync.aligned.m64n240k32.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119}," + "{%120, %121, %122, %123}," + " %124," + " %125, %126," + " p, %128, %129, %130;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), + "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115), + "+f"(d116), "+f"(d117), "+f"(d118), "+f"(d119) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x240x32_F32F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x256x32 F32+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x256x32_F32F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[128]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, 
float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + float & d112, float & d113, float & d114, float & d115, + float & d116, float & d117, float & d118, float & d119, + float & d120, float & d121, float & d122, float & d123, + float & d124, float & d125, float & d126, float & d127, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %132, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n256k32.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119, " + " %120, %121, %122, %123, %124, %125, %126, %127}," + " %128," + " %129," + " %130, %131," + " p, %133, %134, %135, %136;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), 
"+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), + "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115), + "+f"(d116), "+f"(d117), "+f"(d118), "+f"(d119), + "+f"(d120), "+f"(d121), "+f"(d122), "+f"(d123), + "+f"(d124), "+f"(d125), "+f"(d126), "+f"(d127) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x256x32_F32F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x256x32 F32+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x256x32_F32F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[128]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + float & d112, 
float & d113, float & d114, float & d115, + float & d116, float & d117, float & d118, float & d119, + float & d120, float & d121, float & d122, float & d123, + float & d124, float & d125, float & d126, float & d127, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %135, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n256k32.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119, " + " %120, %121, %122, %123, %124, %125, %126, %127}," + "{%128, %129, %130, %131}," + " %132," + " %133, %134," + " p, %136, %137, %138;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), + "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115), + "+f"(d116), "+f"(d117), "+f"(d118), "+f"(d119), + "+f"(d120), "+f"(d121), "+f"(d122), "+f"(d123), + "+f"(d124), "+f"(d125), "+f"(d126), "+f"(d127) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x256x32_F32F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x8x32 F32+=BF16*BF16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + 
GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x8x32_F32BF16BF16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[4]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d0, float & d1, float & d2, float & d3, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %8, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n8k32.f32.bf16.bf16 " + "{%0, %1, %2, %3}," + " %4," + " %5," + " %6, %7," + " p, %9, %10, %11, %12;\n" + "}\n" + : "+f"(d0), "+f"(d1), "+f"(d2), "+f"(d3) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x8x32_F32BF16BF16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x8x32 F32+=BF16*BF16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x8x32_F32BF16BF16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[4]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint64_t const& desc_b, + float & d0, float & d1, float & d2, float & d3, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %11, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n8k32.f32.bf16.bf16 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + " %8," + " %9, %10," + " p, %12, %13, %14;\n" + "}\n" + : "+f"(d0), "+f"(d1), "+f"(d2), "+f"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x8x32_F32BF16BF16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x16x32 F32+=BF16*BF16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x16x32_F32BF16BF16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[8]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d0, float & d1, float & d2, float & d3, + float & d4, float & d5, float & d6, float & d7, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg 
.pred p;\n" + "setp.ne.b32 p, %12, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n16k32.f32.bf16.bf16 " + "{%0, %1, %2, %3, %4, %5, %6, %7}," + " %8," + " %9," + " %10, %11," + " p, %13, %14, %15, %16;\n" + "}\n" + : "+f"(d0), "+f"(d1), "+f"(d2), "+f"(d3), + "+f"(d4), "+f"(d5), "+f"(d6), "+f"(d7) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x16x32_F32BF16BF16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x16x32 F32+=BF16*BF16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x16x32_F32BF16BF16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[8]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint64_t const& desc_b, + float & d0, float & d1, float & d2, float & d3, + float & d4, float & d5, float & d6, float & d7, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %15, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n16k32.f32.bf16.bf16 " + "{%0, %1, %2, %3, %4, %5, %6, %7}," + "{%8, %9, %10, %11}," + " %12," + " %13, %14," + " p, %16, %17, %18;\n" + "}\n" + : "+f"(d0), "+f"(d1), "+f"(d2), "+f"(d3), + "+f"(d4), "+f"(d5), "+f"(d6), "+f"(d7) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x16x32_F32BF16BF16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x32x32 F32+=BF16*BF16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x32x32_F32BF16BF16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[16]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %20, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n32k32.f32.bf16.bf16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15}," + " %16," + " %17," + " %18, %19," + " p, %21, %22, %23, 
%24;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x32x32_F32BF16BF16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x32x32 F32+=BF16*BF16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x32x32_F32BF16BF16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[16]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %23, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n32k32.f32.bf16.bf16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15}," + "{%16, %17, %18, %19}," + " %20," + " %21, %22," + " p, %24, %25, %26;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x32x32_F32BF16BF16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x48x32 F32+=BF16*BF16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x48x32_F32BF16BF16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[24]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + uint32_t const& e, + GMMA::ScaleOut const scale_D = 
GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %28, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n48k32.f32.bf16.bf16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23}," + " %24," + " %25," + " %26, %27," + " p, %29, %30, %31, %32;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x48x32_F32BF16BF16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x48x32 F32+=BF16*BF16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x48x32_F32BF16BF16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[24]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %31, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n48k32.f32.bf16.bf16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23}," + "{%24, %25, %26, %27}," + " %28," + " %29, %30," + " p, %32, %33, %34;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x48x32_F32BF16BF16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x64x32 F32+=BF16*BF16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = 
GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x64x32_F32BF16BF16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[32]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %36, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n64k32.f32.bf16.bf16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31}," + " %32," + " %33," + " %34, %35," + " p, %37, %38, %39, %40;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x64x32_F32BF16BF16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x64x32 F32+=BF16*BF16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x64x32_F32BF16BF16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[32]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %39, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n64k32.f32.bf16.bf16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, 
%15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31}," + "{%32, %33, %34, %35}," + " %36," + " %37, %38," + " p, %40, %41, %42;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x64x32_F32BF16BF16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x80x32 F32+=BF16*BF16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x80x32_F32BF16BF16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[40]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %44, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n80k32.f32.bf16.bf16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39}," + " %40," + " %41," + " %42, %43," + " p, %45, %46, %47, %48;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x80x32_F32BF16BF16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + 
+//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x80x32 F32+=BF16*BF16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x80x32_F32BF16BF16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[40]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %47, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n80k32.f32.bf16.bf16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39}," + "{%40, %41, %42, %43}," + " %44," + " %45, %46," + " p, %48, %49, %50;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x80x32_F32BF16BF16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x96x32 F32+=BF16*BF16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x96x32_F32BF16BF16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[48]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & 
d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %52, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n96k32.f32.bf16.bf16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47}," + " %48," + " %49," + " %50, %51," + " p, %53, %54, %55, %56;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x96x32_F32BF16BF16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x96x32 F32+=BF16*BF16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x96x32_F32BF16BF16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[48]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred 
p;\n" + "setp.ne.b32 p, %55, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n96k32.f32.bf16.bf16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47}," + "{%48, %49, %50, %51}," + " %52," + " %53, %54," + " p, %56, %57, %58;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x96x32_F32BF16BF16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x112x32 F32+=BF16*BF16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x112x32_F32BF16BF16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[56]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %60, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n112k32.f32.bf16.bf16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55}," + " %56," + " %57," + " %58, %59," + " p, %61, %62, %63, %64;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), 
"+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x112x32_F32BF16BF16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x112x32 F32+=BF16*BF16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x112x32_F32BF16BF16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[56]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %63, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n112k32.f32.bf16.bf16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55}," + "{%56, %57, %58, %59}," + " %60," + " %61, %62," + " p, %64, %65, %66;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), 
"+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x112x32_F32BF16BF16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x128x32 F32+=BF16*BF16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x128x32_F32BF16BF16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[64]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %68, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n128k32.f32.bf16.bf16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63}," + " %64," + " %65," + " %66, %67," + " p, %69, %70, %71, %72;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + 
"r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x128x32_F32BF16BF16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x128x32 F32+=BF16*BF16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x128x32_F32BF16BF16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[64]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %71, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n128k32.f32.bf16.bf16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63}," + "{%64, %65, %66, %67}," + " %68," + " %69, %70," + " p, %72, %73, %74;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + 
CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x128x32_F32BF16BF16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x144x32 F32+=BF16*BF16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x144x32_F32BF16BF16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[72]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %76, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n144k32.f32.bf16.bf16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71}," + " %72," + " %73," + " %74, %75," + " p, %77, %78, %79, %80;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + 
CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x144x32_F32BF16BF16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x144x32 F32+=BF16*BF16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x144x32_F32BF16BF16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[72]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %79, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n144k32.f32.bf16.bf16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71}," + "{%72, %73, %74, %75}," + " %76," + " %77, %78," + " p, %80, %81, %82;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71) + : 
"r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x144x32_F32BF16BF16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x160x32 F32+=BF16*BF16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x160x32_F32BF16BF16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[80]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %84, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n160k32.f32.bf16.bf16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79}," + " %80," + " %81," + " %82, %83," + " p, %85, %86, %87, %88;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), 
"+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x160x32_F32BF16BF16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x160x32 F32+=BF16*BF16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x160x32_F32BF16BF16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[80]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %87, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n160k32.f32.bf16.bf16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79}," + "{%80, %81, %82, %83}," + " %84," + " %85, %86," + " p, %88, %89, %90;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + 
"+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x160x32_F32BF16BF16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x176x32 F32+=BF16*BF16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x176x32_F32BF16BF16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[88]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + float & d80, float & d81, float & d82, float & d83, + float & d84, float & d85, float & d86, float & d87, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %92, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n176k32.f32.bf16.bf16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, 
%69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87}," + " %88," + " %89," + " %90, %91," + " p, %93, %94, %95, %96;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), + "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83), + "+f"(d84), "+f"(d85), "+f"(d86), "+f"(d87) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x176x32_F32BF16BF16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x176x32 F32+=BF16*BF16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x176x32_F32BF16BF16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[88]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + float & d80, float & d81, float & d82, float & d83, + float & d84, float & d85, float & 
d86, float & d87, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %95, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n176k32.f32.bf16.bf16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87}," + "{%88, %89, %90, %91}," + " %92," + " %93, %94," + " p, %96, %97, %98;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), + "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83), + "+f"(d84), "+f"(d85), "+f"(d86), "+f"(d87) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x176x32_F32BF16BF16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x192x32 F32+=BF16*BF16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x192x32_F32BF16BF16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[96]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + 
float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + float & d80, float & d81, float & d82, float & d83, + float & d84, float & d85, float & d86, float & d87, + float & d88, float & d89, float & d90, float & d91, + float & d92, float & d93, float & d94, float & d95, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %100, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n192k32.f32.bf16.bf16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95}," + " %96," + " %97," + " %98, %99," + " p, %101, %102, %103, %104;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), + "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83), + "+f"(d84), "+f"(d85), "+f"(d86), "+f"(d87), + "+f"(d88), "+f"(d89), "+f"(d90), "+f"(d91), + "+f"(d92), "+f"(d93), "+f"(d94), "+f"(d95) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x192x32_F32BF16BF16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x192x32 F32+=BF16*BF16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x192x32_F32BF16BF16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[96]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + 
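+  // Operand layout for the sparse wgmma call below (as wired up in the asm
+  // constraints): 96 f32 accumulators (read-modify-write), the four 32-bit
+  // register fragments of A, the B-matrix descriptor, the 32-bit sparse
+  // metadata word `e`, the immediate sparse selector, the scale-D predicate,
+  // and the immediate scaleA/scaleB/tnspB modifiers.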
+ CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + float & d80, float & d81, float & d82, float & d83, + float & d84, float & d85, float & d86, float & d87, + float & d88, float & d89, float & d90, float & d91, + float & d92, float & d93, float & d94, float & d95, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %103, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n192k32.f32.bf16.bf16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95}," + "{%96, %97, %98, %99}," + " %100," + " %101, %102," + " p, %104, %105, %106;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), + "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83), + "+f"(d84), "+f"(d85), "+f"(d86), "+f"(d87), + "+f"(d88), "+f"(d89), "+f"(d90), "+f"(d91), + "+f"(d92), "+f"(d93), "+f"(d94), "+f"(d95) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), 
"n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x192x32_F32BF16BF16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x208x32 F32+=BF16*BF16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x208x32_F32BF16BF16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[104]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %108, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n208k32.f32.bf16.bf16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103}," + " %104," + " %105," + " %106, %107," + " p, %109, %110, %111, %112;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + 
"+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x208x32_F32BF16BF16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x208x32 F32+=BF16*BF16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x208x32_F32BF16BF16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[104]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, 
float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %111, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n208k32.f32.bf16.bf16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103}," + "{%104, %105, %106, %107}," + " %108," + " %109, %110," + " p, %112, %113, %114;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x208x32_F32BF16BF16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x224x32 F32+=BF16*BF16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x224x32_F32BF16BF16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[112]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d000, float & 
d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %116, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n224k32.f32.bf16.bf16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111}," + " %112," + " %113," + " %114, %115," + " p, %117, %118, %119, %120;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), 
"+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x224x32_F32BF16BF16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x224x32 F32+=BF16*BF16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x224x32_F32BF16BF16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[112]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %119, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n224k32.f32.bf16.bf16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, 
%9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111}," + "{%112, %113, %114, %115}," + " %116," + " %117, %118," + " p, %120, %121, %122;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x224x32_F32BF16BF16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x240x32 F32+=BF16*BF16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x240x32_F32BF16BF16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[120]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, 
float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + float & d112, float & d113, float & d114, float & d115, + float & d116, float & d117, float & d118, float & d119, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %124, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n240k32.f32.bf16.bf16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119}," + " %120," + " %121," + " %122, %123," + " p, %125, %126, %127, %128;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), 
"+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), + "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115), + "+f"(d116), "+f"(d117), "+f"(d118), "+f"(d119) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x240x32_F32BF16BF16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x240x32 F32+=BF16*BF16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x240x32_F32BF16BF16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[120]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + float & d112, float & d113, float & d114, float & d115, + float & d116, float & d117, float & d118, float & d119, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %127, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n240k32.f32.bf16.bf16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, 
%10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119}," + "{%120, %121, %122, %123}," + " %124," + " %125, %126," + " p, %128, %129, %130;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), + "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115), + "+f"(d116), "+f"(d117), "+f"(d118), "+f"(d119) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x240x32_F32BF16BF16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x256x32 F32+=BF16*BF16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x256x32_F32BF16BF16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[128]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, 
float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + float & d112, float & d113, float & d114, float & d115, + float & d116, float & d117, float & d118, float & d119, + float & d120, float & d121, float & d122, float & d123, + float & d124, float & d125, float & d126, float & d127, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %132, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n256k32.f32.bf16.bf16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119, " + " %120, %121, %122, %123, %124, %125, %126, %127}," + " %128," + " %129," + " %130, %131," + " p, %133, %134, %135, %136;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), 
"+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), + "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115), + "+f"(d116), "+f"(d117), "+f"(d118), "+f"(d119), + "+f"(d120), "+f"(d121), "+f"(d122), "+f"(d123), + "+f"(d124), "+f"(d125), "+f"(d126), "+f"(d127) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x256x32_F32BF16BF16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x256x32 F32+=BF16*BF16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x256x32_F32BF16BF16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[128]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + float & d112, float & d113, float & d114, float & d115, + float & d116, float & d117, float & d118, float & 
d119, + float & d120, float & d121, float & d122, float & d123, + float & d124, float & d125, float & d126, float & d127, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %135, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n256k32.f32.bf16.bf16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119, " + " %120, %121, %122, %123, %124, %125, %126, %127}," + "{%128, %129, %130, %131}," + " %132," + " %133, %134," + " p, %136, %137, %138;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), + "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115), + "+f"(d116), "+f"(d117), "+f"(d118), "+f"(d119), + "+f"(d120), "+f"(d121), "+f"(d122), "+f"(d123), + "+f"(d124), "+f"(d125), "+f"(d126), "+f"(d127) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x256x32_F32BF16BF16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x8x16 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x8x16_F32TF32TF32_SS_TN +{ + using DRegisters = void; + using 
ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[4]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d0, float & d1, float & d2, float & d3, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %8, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n8k16.f32.tf32.tf32 " + "{%0, %1, %2, %3}," + " %4," + " %5," + " %6, %7," + " p, %9, %10;\n" + "}\n" + : "+f"(d0), "+f"(d1), "+f"(d2), "+f"(d3) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x8x16_F32TF32TF32_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x8x16 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x8x16_F32TF32TF32_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[4]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint64_t const& desc_b, + float & d0, float & d1, float & d2, float & d3, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %11, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n8k16.f32.tf32.tf32 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + " %8," + " %9, %10," + " p, %12, %13;\n" + "}\n" + : "+f"(d0), "+f"(d1), "+f"(d2), "+f"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x8x16_F32TF32TF32_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x16x16 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x16x16_F32TF32TF32_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[8]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d0, float & d1, float & d2, float & d3, + float & d4, float & d5, float & d6, float & d7, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %12, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n16k16.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7}," + " %8," + " %9," + " %10, %11," + " p, %13, %14;\n" + "}\n" + : "+f"(d0), "+f"(d1), "+f"(d2), "+f"(d3), + "+f"(d4), "+f"(d5), "+f"(d6), "+f"(d7) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + 
"r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x16x16_F32TF32TF32_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x16x16 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x16x16_F32TF32TF32_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[8]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint64_t const& desc_b, + float & d0, float & d1, float & d2, float & d3, + float & d4, float & d5, float & d6, float & d7, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %15, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n16k16.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7}," + "{%8, %9, %10, %11}," + " %12," + " %13, %14," + " p, %16, %17;\n" + "}\n" + : "+f"(d0), "+f"(d1), "+f"(d2), "+f"(d3), + "+f"(d4), "+f"(d5), "+f"(d6), "+f"(d7) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x16x16_F32TF32TF32_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x32x16 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x32x16_F32TF32TF32_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[16]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %20, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n32k16.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15}," + " %16," + " %17," + " %18, %19," + " p, %21, %22;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x32x16_F32TF32TF32_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// 
SPARSE GMMA 64x32x16 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x32x16_F32TF32TF32_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[16]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %23, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n32k16.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15}," + "{%16, %17, %18, %19}," + " %20," + " %21, %22," + " p, %24, %25;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x32x16_F32TF32TF32_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x48x16 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x48x16_F32TF32TF32_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[24]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %28, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n48k16.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23}," + " %24," + " %25," + " %26, %27," + " p, %29, %30;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + 
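+    // Fallback when CUTE_ARCH_MMA_SM90A_ENABLED is not defined: report an
+    // invalid control path instead of issuing the wgmma instruction.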
CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x48x16_F32TF32TF32_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x48x16 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x48x16_F32TF32TF32_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[24]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %31, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n48k16.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23}," + "{%24, %25, %26, %27}," + " %28," + " %29, %30," + " p, %32, %33;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x48x16_F32TF32TF32_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x64x16 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x64x16_F32TF32TF32_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[32]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %36, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n64k16.f32.tf32.tf32 " + "{%0, %1, 
%2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31}," + " %32," + " %33," + " %34, %35," + " p, %37, %38;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x64x16_F32TF32TF32_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x64x16 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x64x16_F32TF32TF32_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[32]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %39, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n64k16.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31}," + "{%32, %33, %34, %35}," + " %36," + " %37, %38," + " p, %40, %41;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x64x16_F32TF32TF32_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x80x16 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct 
GMMA_64x80x16_F32TF32TF32_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[40]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %44, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n80k16.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39}," + " %40," + " %41," + " %42, %43," + " p, %45, %46;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x80x16_F32TF32TF32_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x80x16 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x80x16_F32TF32TF32_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[40]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 
p, %47, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n80k16.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39}," + "{%40, %41, %42, %43}," + " %44," + " %45, %46," + " p, %48, %49;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x80x16_F32TF32TF32_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x96x16 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x96x16_F32TF32TF32_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[48]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %52, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n96k16.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47}," + " %48," + " %49," + " %50, %51," + " p, %53, %54;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47) + : 
"l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x96x16_F32TF32TF32_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x96x16 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x96x16_F32TF32TF32_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[48]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %55, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n96k16.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47}," + "{%48, %49, %50, %51}," + " %52," + " %53, %54," + " p, %56, %57;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x96x16_F32TF32TF32_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x112x16 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x112x16_F32TF32TF32_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = 
uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[56]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %60, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n112k16.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55}," + " %56," + " %57," + " %58, %59," + " p, %61, %62;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x112x16_F32TF32TF32_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x112x16 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x112x16_F32TF32TF32_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[56]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, 
float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %63, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n112k16.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55}," + "{%56, %57, %58, %59}," + " %60," + " %61, %62," + " p, %64, %65;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x112x16_F32TF32TF32_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x128x16 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x128x16_F32TF32TF32_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[64]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + uint32_t const& e, + GMMA::ScaleOut const scale_D = 
GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %68, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n128k16.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63}," + " %64," + " %65," + " %66, %67," + " p, %69, %70;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x128x16_F32TF32TF32_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x128x16 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x128x16_F32TF32TF32_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[64]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %71, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n128k16.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, 
%23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63}," + "{%64, %65, %66, %67}," + " %68," + " %69, %70," + " p, %72, %73;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x128x16_F32TF32TF32_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x144x16 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x144x16_F32TF32TF32_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[72]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %76, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n144k16.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, 
%54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71}," + " %72," + " %73," + " %74, %75," + " p, %77, %78;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x144x16_F32TF32TF32_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x144x16 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x144x16_F32TF32TF32_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[72]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %79, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n144k16.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " 
%56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71}," + "{%72, %73, %74, %75}," + " %76," + " %77, %78," + " p, %80, %81;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x144x16_F32TF32TF32_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x160x16 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x160x16_F32TF32TF32_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[80]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %84, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n160k16.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, 
%44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79}," + " %80," + " %81," + " %82, %83," + " p, %85, %86;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x160x16_F32TF32TF32_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x160x16 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x160x16_F32TF32TF32_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[80]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %87, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n160k16.f32.tf32.tf32 " + "{%0, %1, %2, 
%3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79}," + "{%80, %81, %82, %83}," + " %84," + " %85, %86," + " p, %88, %89;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x160x16_F32TF32TF32_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x176x16 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x176x16_F32TF32TF32_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[88]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + float & d80, float & d81, float & d82, float & d83, + float 
& d84, float & d85, float & d86, float & d87, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %92, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n176k16.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87}," + " %88," + " %89," + " %90, %91," + " p, %93, %94;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), + "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83), + "+f"(d84), "+f"(d85), "+f"(d86), "+f"(d87) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x176x16_F32TF32TF32_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x176x16 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x176x16_F32TF32TF32_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[88]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float 
& d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + float & d80, float & d81, float & d82, float & d83, + float & d84, float & d85, float & d86, float & d87, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %95, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n176k16.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87}," + "{%88, %89, %90, %91}," + " %92," + " %93, %94," + " p, %96, %97;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), + "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83), + "+f"(d84), "+f"(d85), "+f"(d86), "+f"(d87) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x176x16_F32TF32TF32_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x192x16 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x192x16_F32TF32TF32_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[96]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & 
d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + float & d80, float & d81, float & d82, float & d83, + float & d84, float & d85, float & d86, float & d87, + float & d88, float & d89, float & d90, float & d91, + float & d92, float & d93, float & d94, float & d95, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %100, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n192k16.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95}," + " %96," + " %97," + " %98, %99," + " p, %101, %102;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), + "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83), + "+f"(d84), "+f"(d85), "+f"(d86), "+f"(d87), + "+f"(d88), "+f"(d89), "+f"(d90), "+f"(d91), + "+f"(d92), "+f"(d93), "+f"(d94), "+f"(d95) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x192x16_F32TF32TF32_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x192x16 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + 
GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x192x16_F32TF32TF32_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[96]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + float & d80, float & d81, float & d82, float & d83, + float & d84, float & d85, float & d86, float & d87, + float & d88, float & d89, float & d90, float & d91, + float & d92, float & d93, float & d94, float & d95, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %103, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n192k16.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95}," + "{%96, %97, %98, %99}," + " %100," + " %101, %102," + " p, %104, %105;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), + "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83), + 
"+f"(d84), "+f"(d85), "+f"(d86), "+f"(d87), + "+f"(d88), "+f"(d89), "+f"(d90), "+f"(d91), + "+f"(d92), "+f"(d93), "+f"(d94), "+f"(d95) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x192x16_F32TF32TF32_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x208x16 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x208x16_F32TF32TF32_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[104]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %108, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n208k16.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103}," + " %104," + " %105," + " %106, %107," + " p, %109, %110;\n" + "}\n" + : "+f"(d000), 
"+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x208x16_F32TF32TF32_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x208x16 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x208x16_F32TF32TF32_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[104]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, 
float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %111, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n208k16.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103}," + "{%104, %105, %106, %107}," + " %108," + " %109, %110," + " p, %112, %113;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x208x16_F32TF32TF32_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x224x16 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x224x16_F32TF32TF32_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[112]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & 
d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %116, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n224k16.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111}," + " %112," + " %113," + " %114, %115," + " p, %117, %118;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), 
"+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x224x16_F32TF32TF32_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x224x16 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x224x16_F32TF32TF32_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[112]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %119, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n224k16.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " 
%48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111}," + "{%112, %113, %114, %115}," + " %116," + " %117, %118," + " p, %120, %121;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x224x16_F32TF32TF32_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x240x16 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x240x16_F32TF32TF32_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[120]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & 
d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + float & d112, float & d113, float & d114, float & d115, + float & d116, float & d117, float & d118, float & d119, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %124, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n240k16.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119}," + " %120," + " %121," + " %122, %123," + " p, %125, %126;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), + "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115), + "+f"(d116), "+f"(d117), "+f"(d118), 
"+f"(d119) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x240x16_F32TF32TF32_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x240x16 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x240x16_F32TF32TF32_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[120]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + float & d112, float & d113, float & d114, float & d115, + float & d116, float & d117, float & d118, float & d119, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %127, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n240k16.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, 
%94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119}," + "{%120, %121, %122, %123}," + " %124," + " %125, %126," + " p, %128, %129;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), + "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115), + "+f"(d116), "+f"(d117), "+f"(d118), "+f"(d119) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x240x16_F32TF32TF32_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x256x16 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x256x16_F32TF32TF32_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[128]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & 
d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + float & d112, float & d113, float & d114, float & d115, + float & d116, float & d117, float & d118, float & d119, + float & d120, float & d121, float & d122, float & d123, + float & d124, float & d125, float & d126, float & d127, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %132, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n256k16.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119, " + " %120, %121, %122, %123, %124, %125, %126, %127}," + " %128," + " %129," + " %130, %131," + " p, %133, %134;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), + "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115), + "+f"(d116), "+f"(d117), 
"+f"(d118), "+f"(d119), + "+f"(d120), "+f"(d121), "+f"(d122), "+f"(d123), + "+f"(d124), "+f"(d125), "+f"(d126), "+f"(d127) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x256x16_F32TF32TF32_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x256x16 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x256x16_F32TF32TF32_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[128]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + float & d112, float & d113, float & d114, float & d115, + float & d116, float & d117, float & d118, float & d119, + float & d120, float & d121, float & d122, float & d123, + float & d124, float & d125, float & d126, float & d127, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %135, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n256k16.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " 
%64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119, " + " %120, %121, %122, %123, %124, %125, %126, %127}," + "{%128, %129, %130, %131}," + " %132," + " %133, %134," + " p, %136, %137;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), + "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115), + "+f"(d116), "+f"(d117), "+f"(d118), "+f"(d119), + "+f"(d120), "+f"(d121), "+f"(d122), "+f"(d123), + "+f"(d124), "+f"(d125), "+f"(d126), "+f"(d127) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x256x16_F32TF32TF32_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x8x64 TN S32+=S8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x8x64_S32S8S8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[4]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %8, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n8k64.s32.s8.s8 " + "{%0, %1, %2, %3}," + " %4," + " %5," + " %6, %7," + " p;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x8x64_S32S8S8_SS_TN without 
CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x8x64 TN S32+=S8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x8x64_S32S8S8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[4]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %8, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n8k64.s32.s8.s8.satfinite " + "{%0, %1, %2, %3}," + " %4," + " %5," + " %6, %7," + " p;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x8x64_S32S8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x16x64 TN S32+=S8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x16x64_S32S8S8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[8]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t & d4, uint32_t & d5, uint32_t & d6, uint32_t & d7, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %12, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n16k64.s32.s8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7}," + " %8," + " %9," + " %10, %11," + " p;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3), + "+r"(d4), "+r"(d5), "+r"(d6), "+r"(d7) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x16x64_S32S8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x16x64 TN S32+=S8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x16x64_S32S8S8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[8]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t & d4, uint32_t & d5, uint32_t & d6, uint32_t & d7, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %12, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n16k64.s32.s8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7}," + " %8," + " %9," + " %10, %11," + " p;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), 
"+r"(d3), + "+r"(d4), "+r"(d5), "+r"(d6), "+r"(d7) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x16x64_S32S8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x32x64 TN S32+=S8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x32x64_S32S8S8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[16]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %20, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n32k64.s32.s8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15}," + " %16," + " %17," + " %18, %19," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x32x64_S32S8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x32x64 TN S32+=S8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x32x64_S32S8S8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[16]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %20, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n32k64.s32.s8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15}," + " %16," + " %17," + " %18, %19," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x32x64_S32S8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + 
+//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x48x64 TN S32+=S8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x48x64_S32S8S8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[24]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %28, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n48k64.s32.s8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23}," + " %24," + " %25," + " %26, %27," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x48x64_S32S8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x48x64 TN S32+=S8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x48x64_S32S8S8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[24]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %28, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n48k64.s32.s8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23}," + " %24," + " %25," + " %26, %27," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), 
"+r"(d21), "+r"(d22), "+r"(d23) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x48x64_S32S8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x64x64 TN S32+=S8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x64x64_S32S8S8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[32]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %36, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n64k64.s32.s8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31}," + " %32," + " %33," + " %34, %35," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x64x64_S32S8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x64x64 TN S32+=S8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x64x64_S32S8S8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[32]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if 
defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %36, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n64k64.s32.s8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31}," + " %32," + " %33," + " %34, %35," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x64x64_S32S8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x80x64 TN S32+=S8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x80x64_S32S8S8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[40]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %44, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n80k64.s32.s8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39}," + " %40," + " %41," + " %42, %43," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x80x64_S32S8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + 
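+// Note (editorial sketch): several of the N tile sizes in this file (48, 80,
+// 208, 224, and 240) are wrapped in CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED and
+// are only compiled when that macro is defined, while the remaining shapes are
+// always available. The *_SATURATE variants differ only in the .satfinite
+// qualifier on the wgmma instruction, which saturates the signed 32-bit
+// accumulators instead of letting them wrap on overflow. A hypothetical build
+// enabling both the SM90A path and the extended shapes might use:
+//
+//   nvcc -arch=sm_90a -DCUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED ...
+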
+//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x80x64 TN S32+=S8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x80x64_S32S8S8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[40]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %44, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n80k64.s32.s8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39}," + " %40," + " %41," + " %42, %43," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x80x64_S32S8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x96x64 TN S32+=S8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x96x64_S32S8S8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[48]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, 
uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %52, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n96k64.s32.s8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47}," + " %48," + " %49," + " %50, %51," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x96x64_S32S8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x96x64 TN S32+=S8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x96x64_S32S8S8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[48]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %52, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n96k64.s32.s8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47}," + " %48," + " %49," + " %50, %51," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), 
"+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x96x64_S32S8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x112x64 TN S32+=S8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x112x64_S32S8S8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[56]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %60, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n112k64.s32.s8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55}," + " %56," + " %57," + " %58, %59," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); 
+#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x112x64_S32S8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x112x64 TN S32+=S8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x112x64_S32S8S8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[56]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %60, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n112k64.s32.s8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55}," + " %56," + " %57," + " %58, %59," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x112x64_S32S8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x128x64 TN S32+=S8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x128x64_S32S8S8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = 
uint64_t[1]; + using CRegisters = uint32_t[64]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %68, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n128k64.s32.s8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63}," + " %64," + " %65," + " %66, %67," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x128x64_S32S8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x128x64 TN S32+=S8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x128x64_S32S8S8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[64]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + 
uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %68, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n128k64.s32.s8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63}," + " %64," + " %65," + " %66, %67," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x128x64_S32S8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x144x64 TN S32+=S8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x144x64_S32S8S8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[72]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & 
d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %76, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n144k64.s32.s8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71}," + " %72," + " %73," + " %74, %75," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x144x64_S32S8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x144x64 TN S32+=S8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x144x64_S32S8S8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[72]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, 
uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %76, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n144k64.s32.s8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71}," + " %72," + " %73," + " %74, %75," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x144x64_S32S8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x160x64 TN S32+=S8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x160x64_S32S8S8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[80]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, 
uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, + uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %84, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n160k64.s32.s8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79}," + " %80," + " %81," + " %82, %83," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), + "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), + "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x160x64_S32S8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x160x64 TN S32+=S8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x160x64_S32S8S8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[80]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, 
uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, + uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %84, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n160k64.s32.s8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79}," + " %80," + " %81," + " %82, %83," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), + "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), + "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x160x64_S32S8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x176x64 TN S32+=S8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x176x64_S32S8S8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[88]; + + CUTE_HOST_DEVICE static void + 
fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, + uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, + uint32_t & d80, uint32_t & d81, uint32_t & d82, uint32_t & d83, + uint32_t & d84, uint32_t & d85, uint32_t & d86, uint32_t & d87, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %92, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n176k64.s32.s8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87}," + " %88," + " %89," + " %90, %91," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), + "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), + "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79), + "+r"(d80), "+r"(d81), "+r"(d82), "+r"(d83), + "+r"(d84), "+r"(d85), "+r"(d86), "+r"(d87) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x176x64_S32S8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); 
+#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x176x64 TN S32+=S8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x176x64_S32S8S8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[88]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, + uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, + uint32_t & d80, uint32_t & d81, uint32_t & d82, uint32_t & d83, + uint32_t & d84, uint32_t & d85, uint32_t & d86, uint32_t & d87, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %92, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n176k64.s32.s8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87}," + " %88," + " %89," + " %90, %91," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), 
"+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), + "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), + "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79), + "+r"(d80), "+r"(d81), "+r"(d82), "+r"(d83), + "+r"(d84), "+r"(d85), "+r"(d86), "+r"(d87) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x176x64_S32S8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x192x64 TN S32+=S8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x192x64_S32S8S8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[96]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, + uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, + uint32_t & d80, uint32_t & d81, uint32_t & d82, uint32_t & d83, + uint32_t & d84, uint32_t & d85, uint32_t & d86, uint32_t & d87, + uint32_t & d88, uint32_t & d89, uint32_t & d90, uint32_t & d91, + uint32_t & d92, uint32_t & d93, uint32_t & d94, uint32_t & d95, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %100, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n192k64.s32.s8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95}," + " %96," + " %97," + " %98, %99," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), 
+ "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), + "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), + "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79), + "+r"(d80), "+r"(d81), "+r"(d82), "+r"(d83), + "+r"(d84), "+r"(d85), "+r"(d86), "+r"(d87), + "+r"(d88), "+r"(d89), "+r"(d90), "+r"(d91), + "+r"(d92), "+r"(d93), "+r"(d94), "+r"(d95) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x192x64_S32S8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x192x64 TN S32+=S8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x192x64_S32S8S8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[96]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, + uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, + uint32_t & d80, uint32_t & d81, uint32_t & d82, uint32_t & d83, + uint32_t & d84, uint32_t & d85, uint32_t & d86, uint32_t & d87, + uint32_t & d88, uint32_t & d89, uint32_t & d90, uint32_t & d91, + uint32_t & d92, uint32_t & d93, uint32_t & d94, uint32_t & d95, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg 
.pred p;\n" + "setp.ne.b32 p, %100, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n192k64.s32.s8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95}," + " %96," + " %97," + " %98, %99," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), + "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), + "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79), + "+r"(d80), "+r"(d81), "+r"(d82), "+r"(d83), + "+r"(d84), "+r"(d85), "+r"(d86), "+r"(d87), + "+r"(d88), "+r"(d89), "+r"(d90), "+r"(d91), + "+r"(d92), "+r"(d93), "+r"(d94), "+r"(d95) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x192x64_S32S8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x208x64 TN S32+=S8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x208x64_S32S8S8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[104]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, 
uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %108, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n208k64.s32.s8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103}," + " %104," + " %105," + " %106, %107," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x208x64_S32S8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x208x64 TN S32+=S8*S8 
+template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x208x64_S32S8S8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[104]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %108, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n208k64.s32.s8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103}," + " %104," + " %105," + " %106, %107," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + 
"+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x208x64_S32S8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x224x64 TN S32+=S8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x224x64_S32S8S8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[112]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & 
d107, + uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %116, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n224k64.s32.s8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111}," + " %112," + " %113," + " %114, %115," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), + "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), + "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x224x64_S32S8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x224x64 TN S32+=S8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x224x64_S32S8S8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[112]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, 
uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, + uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %116, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n224k64.s32.s8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111}," + " %112," + " %113," + " %114, %115," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), 
"+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), + "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), + "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x224x64_S32S8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x240x64 TN S32+=S8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x240x64_S32S8S8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[120]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, + uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, + uint32_t & d112, uint32_t & d113, uint32_t & d114, uint32_t & d115, + uint32_t & d116, uint32_t & d117, uint32_t & d118, uint32_t & d119, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred 
p;\n" + "setp.ne.b32 p, %124, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n240k64.s32.s8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119}," + " %120," + " %121," + " %122, %123," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), + "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), + "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111), + "+r"(d112), "+r"(d113), "+r"(d114), "+r"(d115), + "+r"(d116), "+r"(d117), "+r"(d118), "+r"(d119) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x240x64_S32S8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x240x64 TN S32+=S8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x240x64_S32S8S8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[120]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & 
d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, + uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, + uint32_t & d112, uint32_t & d113, uint32_t & d114, uint32_t & d115, + uint32_t & d116, uint32_t & d117, uint32_t & d118, uint32_t & d119, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %124, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n240k64.s32.s8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119}," + " %120," + " %121," + " %122, %123," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + 
"+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), + "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), + "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111), + "+r"(d112), "+r"(d113), "+r"(d114), "+r"(d115), + "+r"(d116), "+r"(d117), "+r"(d118), "+r"(d119) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x240x64_S32S8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x256x64 TN S32+=S8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x256x64_S32S8S8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[128]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, + uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, + uint32_t & d112, uint32_t & d113, uint32_t & d114, uint32_t & d115, + uint32_t & d116, uint32_t & d117, uint32_t & d118, 
uint32_t & d119, + uint32_t & d120, uint32_t & d121, uint32_t & d122, uint32_t & d123, + uint32_t & d124, uint32_t & d125, uint32_t & d126, uint32_t & d127, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %132, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n256k64.s32.s8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119, " + " %120, %121, %122, %123, %124, %125, %126, %127}," + " %128," + " %129," + " %130, %131," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), + "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), + "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111), + "+r"(d112), "+r"(d113), "+r"(d114), "+r"(d115), + "+r"(d116), "+r"(d117), "+r"(d118), "+r"(d119), + "+r"(d120), "+r"(d121), "+r"(d122), "+r"(d123), + "+r"(d124), "+r"(d125), "+r"(d126), "+r"(d127) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x256x64_S32S8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x256x64 TN S32+=S8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x256x64_S32S8S8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[128]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, 
+ uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, + uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, + uint32_t & d112, uint32_t & d113, uint32_t & d114, uint32_t & d115, + uint32_t & d116, uint32_t & d117, uint32_t & d118, uint32_t & d119, + uint32_t & d120, uint32_t & d121, uint32_t & d122, uint32_t & d123, + uint32_t & d124, uint32_t & d125, uint32_t & d126, uint32_t & d127, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %132, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n256k64.s32.s8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119, " + " %120, %121, %122, %123, %124, %125, %126, %127}," + " %128," + " %129," + " %130, %131," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + 
"+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), + "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), + "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111), + "+r"(d112), "+r"(d113), "+r"(d114), "+r"(d115), + "+r"(d116), "+r"(d117), "+r"(d118), "+r"(d119), + "+r"(d120), "+r"(d121), "+r"(d122), "+r"(d123), + "+r"(d124), "+r"(d125), "+r"(d126), "+r"(d127) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x256x64_S32S8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x8x64 TN S32+=S8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x8x64_S32S8S8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[4]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %11, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n8k64.s32.s8.s8 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + " %8," + " %9, %10," + " p;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x8x64_S32S8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x8x64 TN S32+=S8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x8x64_S32S8S8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[4]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t 
const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %11, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n8k64.s32.s8.s8.satfinite " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + " %8," + " %9, %10," + " p;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x8x64_S32S8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x16x64 TN S32+=S8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x16x64_S32S8S8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[8]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t & d4, uint32_t & d5, uint32_t & d6, uint32_t & d7, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %15, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n16k64.s32.s8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7}," + "{%8, %9, %10, %11}," + " %12," + " %13, %14," + " p;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3), + "+r"(d4), "+r"(d5), "+r"(d6), "+r"(d7) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x16x64_S32S8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x16x64 TN S32+=S8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x16x64_S32S8S8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[8]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t & d4, uint32_t & d5, uint32_t & d6, uint32_t & d7, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %15, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n16k64.s32.s8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7}," + "{%8, %9, %10, %11}," + " %12," + " %13, %14," + " p;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3), + "+r"(d4), "+r"(d5), "+r"(d6), "+r"(d7) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x16x64_S32S8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + 
+//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x32x64 TN S32+=S8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x32x64_S32S8S8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[16]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %23, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n32k64.s32.s8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15}," + "{%16, %17, %18, %19}," + " %20," + " %21, %22," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x32x64_S32S8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x32x64 TN S32+=S8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x32x64_S32S8S8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[16]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %23, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n32k64.s32.s8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15}," + "{%16, %17, %18, %19}," + " %20," + " %21, %22," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x32x64_S32S8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) 
+// SPARSE GMMA 64x48x64 TN S32+=S8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x48x64_S32S8S8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[24]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %31, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n48k64.s32.s8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23}," + "{%24, %25, %26, %27}," + " %28," + " %29, %30," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x48x64_S32S8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x48x64 TN S32+=S8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x48x64_S32S8S8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[24]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %31, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n48k64.s32.s8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23}," + "{%24, %25, %26, %27}," + " %28," + " %29, %30," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), 
"+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x48x64_S32S8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x64x64 TN S32+=S8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x64x64_S32S8S8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[32]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %39, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n64k64.s32.s8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31}," + "{%32, %33, %34, %35}," + " %36," + " %37, %38," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x64x64_S32S8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x64x64 TN S32+=S8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x64x64_S32S8S8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[32]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, 
uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %39, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n64k64.s32.s8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31}," + "{%32, %33, %34, %35}," + " %36," + " %37, %38," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x64x64_S32S8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x80x64 TN S32+=S8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x80x64_S32S8S8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[40]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %47, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n80k64.s32.s8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39}," + "{%40, %41, %42, %43}," + " %44," + " %45, %46," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39) + : 
"r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x80x64_S32S8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x80x64 TN S32+=S8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x80x64_S32S8S8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[40]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %47, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n80k64.s32.s8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39}," + "{%40, %41, %42, %43}," + " %44," + " %45, %46," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x80x64_S32S8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x96x64 TN S32+=S8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x96x64_S32S8S8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[48]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + 
uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %55, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n96k64.s32.s8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47}," + "{%48, %49, %50, %51}," + " %52," + " %53, %54," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x96x64_S32S8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x96x64 TN S32+=S8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x96x64_S32S8S8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[48]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + 
asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %55, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n96k64.s32.s8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47}," + "{%48, %49, %50, %51}," + " %52," + " %53, %54," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x96x64_S32S8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x112x64 TN S32+=S8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x112x64_S32S8S8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[56]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %63, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n112k64.s32.s8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55}," + "{%56, %57, %58, %59}," + " %60," + " %61, %62," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), 
"+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x112x64_S32S8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x112x64 TN S32+=S8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x112x64_S32S8S8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[56]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %63, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n112k64.s32.s8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55}," + "{%56, %57, %58, %59}," + " %60," + " %61, %62," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + 
"+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x112x64_S32S8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x128x64 TN S32+=S8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x128x64_S32S8S8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[64]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %71, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n128k64.s32.s8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63}," + "{%64, %65, %66, %67}," + " %68," + " %69, %70," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + 
"r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x128x64_S32S8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x128x64 TN S32+=S8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x128x64_S32S8S8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[64]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %71, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n128k64.s32.s8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63}," + "{%64, %65, %66, %67}," + " %68," + " %69, %70," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x128x64_S32S8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + 
+//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x144x64 TN S32+=S8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x144x64_S32S8S8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[72]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %79, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n144k64.s32.s8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71}," + "{%72, %73, %74, %75}," + " %76," + " %77, %78," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use 
SM90::GMMA::SPARSE::GMMA_64x144x64_S32S8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x144x64 TN S32+=S8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x144x64_S32S8S8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[72]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %79, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n144k64.s32.s8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71}," + "{%72, %73, %74, %75}," + " %76," + " %77, %78," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + 
"r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x144x64_S32S8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x160x64 TN S32+=S8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x160x64_S32S8S8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[80]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, + uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %87, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n160k64.s32.s8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79}," + "{%80, %81, %82, %83}," + " %84," + " %85, %86," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), 
"+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), + "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), + "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x160x64_S32S8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x160x64 TN S32+=S8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x160x64_S32S8S8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[80]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, + uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %87, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n160k64.s32.s8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79}," + "{%80, %81, %82, %83}," + " %84," + " %85, %86," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + 
"+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), + "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), + "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x160x64_S32S8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x176x64 TN S32+=S8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x176x64_S32S8S8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[88]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, + uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, + uint32_t & d80, uint32_t & d81, uint32_t & d82, uint32_t & d83, + uint32_t & d84, uint32_t & d85, uint32_t & d86, uint32_t & d87, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %95, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n176k64.s32.s8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, 
%47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87}," + "{%88, %89, %90, %91}," + " %92," + " %93, %94," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), + "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), + "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79), + "+r"(d80), "+r"(d81), "+r"(d82), "+r"(d83), + "+r"(d84), "+r"(d85), "+r"(d86), "+r"(d87) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x176x64_S32S8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x176x64 TN S32+=S8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x176x64_S32S8S8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[88]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, + uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, + uint32_t 
& d80, uint32_t & d81, uint32_t & d82, uint32_t & d83, + uint32_t & d84, uint32_t & d85, uint32_t & d86, uint32_t & d87, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %95, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n176k64.s32.s8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87}," + "{%88, %89, %90, %91}," + " %92," + " %93, %94," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), + "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), + "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79), + "+r"(d80), "+r"(d81), "+r"(d82), "+r"(d83), + "+r"(d84), "+r"(d85), "+r"(d86), "+r"(d87) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x176x64_S32S8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x192x64 TN S32+=S8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x192x64_S32S8S8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[96]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & 
d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, + uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, + uint32_t & d80, uint32_t & d81, uint32_t & d82, uint32_t & d83, + uint32_t & d84, uint32_t & d85, uint32_t & d86, uint32_t & d87, + uint32_t & d88, uint32_t & d89, uint32_t & d90, uint32_t & d91, + uint32_t & d92, uint32_t & d93, uint32_t & d94, uint32_t & d95, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %103, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n192k64.s32.s8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95}," + "{%96, %97, %98, %99}," + " %100," + " %101, %102," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), + "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), + "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79), + "+r"(d80), "+r"(d81), "+r"(d82), "+r"(d83), + "+r"(d84), "+r"(d85), "+r"(d86), "+r"(d87), + "+r"(d88), "+r"(d89), "+r"(d90), "+r"(d91), + "+r"(d92), "+r"(d93), "+r"(d94), "+r"(d95) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x192x64_S32S8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x192x64 TN S32+=S8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x192x64_S32S8S8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[96]; + + CUTE_HOST_DEVICE static void + 
fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, + uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, + uint32_t & d80, uint32_t & d81, uint32_t & d82, uint32_t & d83, + uint32_t & d84, uint32_t & d85, uint32_t & d86, uint32_t & d87, + uint32_t & d88, uint32_t & d89, uint32_t & d90, uint32_t & d91, + uint32_t & d92, uint32_t & d93, uint32_t & d94, uint32_t & d95, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %103, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n192k64.s32.s8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95}," + "{%96, %97, %98, %99}," + " %100," + " %101, %102," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), + "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), + "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79), + "+r"(d80), "+r"(d81), "+r"(d82), "+r"(d83), + 
"+r"(d84), "+r"(d85), "+r"(d86), "+r"(d87), + "+r"(d88), "+r"(d89), "+r"(d90), "+r"(d91), + "+r"(d92), "+r"(d93), "+r"(d94), "+r"(d95) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x192x64_S32S8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x208x64 TN S32+=S8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x208x64_S32S8S8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[104]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %111, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n208k64.s32.s8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, 
%79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103}," + "{%104, %105, %106, %107}," + " %108," + " %109, %110," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x208x64_S32S8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x208x64 TN S32+=S8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x208x64_S32S8S8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[104]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & 
d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %111, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n208k64.s32.s8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103}," + "{%104, %105, %106, %107}," + " %108," + " %109, %110," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x208x64_S32S8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x224x64 TN S32+=S8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x224x64_S32S8S8_RS_TN +{ + using DRegisters = 
void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[112]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, + uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %119, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n224k64.s32.s8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111}," + "{%112, %113, %114, %115}," + " %116," + " %117, %118," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), 
"+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), + "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), + "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x224x64_S32S8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x224x64 TN S32+=S8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x224x64_S32S8S8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[112]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & 
d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, + uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %119, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n224k64.s32.s8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111}," + "{%112, %113, %114, %115}," + " %116," + " %117, %118," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), + "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), + "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x224x64_S32S8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x240x64 TN S32+=S8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x240x64_S32S8S8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[120]; 
+ + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, + uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, + uint32_t & d112, uint32_t & d113, uint32_t & d114, uint32_t & d115, + uint32_t & d116, uint32_t & d117, uint32_t & d118, uint32_t & d119, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %127, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n240k64.s32.s8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119}," + "{%120, %121, %122, %123}," + " %124," + " %125, %126," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), 
"+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), + "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), + "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111), + "+r"(d112), "+r"(d113), "+r"(d114), "+r"(d115), + "+r"(d116), "+r"(d117), "+r"(d118), "+r"(d119) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x240x64_S32S8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x240x64 TN S32+=S8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x240x64_S32S8S8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[120]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + 
uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, + uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, + uint32_t & d112, uint32_t & d113, uint32_t & d114, uint32_t & d115, + uint32_t & d116, uint32_t & d117, uint32_t & d118, uint32_t & d119, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %127, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n240k64.s32.s8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119}," + "{%120, %121, %122, %123}," + " %124," + " %125, %126," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), + "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), + "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111), + "+r"(d112), "+r"(d113), "+r"(d114), "+r"(d115), + "+r"(d116), "+r"(d117), "+r"(d118), "+r"(d119) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x240x64_S32S8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + 
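+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+// Note on the RS-variant sparse GMMA wrappers above and below: the A fragment is supplied in
+// four 32-bit registers (a000..a003, "r" constraints), the B operand is referenced through a
+// 64-bit shared-memory matrix descriptor (desc_b, "l" constraint), and the accumulators d...
+// are read-modify-write registers ("+r"). The extra operand e carries the sparsity metadata
+// word (ERegisters), spsel is emitted as an immediate sparsity-selector operand, and scale_D
+// feeds the predicate p (setp.ne.b32 p, scale_D, 0) that the wgmma.mma_async.sp instruction
+// uses to either accumulate into or overwrite D.
+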
+//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x256x64 TN S32+=S8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x256x64_S32S8S8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[128]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, + uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, + uint32_t & d112, uint32_t & d113, uint32_t & d114, uint32_t & d115, + uint32_t & d116, uint32_t & d117, uint32_t & d118, uint32_t & d119, + uint32_t & d120, uint32_t & d121, uint32_t & d122, uint32_t & d123, + uint32_t & d124, uint32_t & d125, uint32_t & d126, uint32_t & d127, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %135, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n256k64.s32.s8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " 
+ " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119, " + " %120, %121, %122, %123, %124, %125, %126, %127}," + "{%128, %129, %130, %131}," + " %132," + " %133, %134," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), + "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), + "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111), + "+r"(d112), "+r"(d113), "+r"(d114), "+r"(d115), + "+r"(d116), "+r"(d117), "+r"(d118), "+r"(d119), + "+r"(d120), "+r"(d121), "+r"(d122), "+r"(d123), + "+r"(d124), "+r"(d125), "+r"(d126), "+r"(d127) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x256x64_S32S8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x256x64 TN S32+=S8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x256x64_S32S8S8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[128]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & 
d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, + uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, + uint32_t & d112, uint32_t & d113, uint32_t & d114, uint32_t & d115, + uint32_t & d116, uint32_t & d117, uint32_t & d118, uint32_t & d119, + uint32_t & d120, uint32_t & d121, uint32_t & d122, uint32_t & d123, + uint32_t & d124, uint32_t & d125, uint32_t & d126, uint32_t & d127, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %135, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n256k64.s32.s8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119, " + " %120, %121, %122, %123, %124, %125, %126, %127}," + "{%128, %129, %130, %131}," + " %132," + " %133, %134," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), 
"+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), + "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), + "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111), + "+r"(d112), "+r"(d113), "+r"(d114), "+r"(d115), + "+r"(d116), "+r"(d117), "+r"(d118), "+r"(d119), + "+r"(d120), "+r"(d121), "+r"(d122), "+r"(d123), + "+r"(d124), "+r"(d125), "+r"(d126), "+r"(d127) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x256x64_S32S8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x8x64 TN S32+=S8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x8x64_S32S8U8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[4]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %8, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n8k64.s32.s8.u8 " + "{%0, %1, %2, %3}," + " %4," + " %5," + " %6, %7," + " p;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x8x64_S32S8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x8x64 TN S32+=S8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x8x64_S32S8U8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[4]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %8, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n8k64.s32.s8.u8.satfinite " + "{%0, %1, %2, %3}," + " %4," + " %5," + " %6, %7," + " p;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x8x64_S32S8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x16x64 TN S32+=S8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x16x64_S32S8U8_SS_TN +{ + using DRegisters = void; + using 
ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[8]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t & d4, uint32_t & d5, uint32_t & d6, uint32_t & d7, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %12, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n16k64.s32.s8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7}," + " %8," + " %9," + " %10, %11," + " p;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3), + "+r"(d4), "+r"(d5), "+r"(d6), "+r"(d7) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x16x64_S32S8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x16x64 TN S32+=S8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x16x64_S32S8U8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[8]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t & d4, uint32_t & d5, uint32_t & d6, uint32_t & d7, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %12, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n16k64.s32.s8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7}," + " %8," + " %9," + " %10, %11," + " p;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3), + "+r"(d4), "+r"(d5), "+r"(d6), "+r"(d7) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x16x64_S32S8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x32x64 TN S32+=S8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x32x64_S32S8U8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[16]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %20, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n32k64.s32.s8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15}," + " %16," + " %17," + " %18, %19," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), 
"+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x32x64_S32S8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x32x64 TN S32+=S8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x32x64_S32S8U8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[16]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %20, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n32k64.s32.s8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15}," + " %16," + " %17," + " %18, %19," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x32x64_S32S8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x48x64 TN S32+=S8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x48x64_S32S8U8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[24]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %28, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n48k64.s32.s8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23}," + " %24," + " %25," + " %26, %27," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + 
"+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x48x64_S32S8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x48x64 TN S32+=S8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x48x64_S32S8U8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[24]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %28, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n48k64.s32.s8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23}," + " %24," + " %25," + " %26, %27," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x48x64_S32S8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x64x64 TN S32+=S8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x64x64_S32S8U8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[32]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %36, 0;\n" + 
"wgmma.mma_async.sp.sync.aligned.m64n64k64.s32.s8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31}," + " %32," + " %33," + " %34, %35," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x64x64_S32S8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x64x64 TN S32+=S8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x64x64_S32S8U8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[32]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %36, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n64k64.s32.s8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31}," + " %32," + " %33," + " %34, %35," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x64x64_S32S8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x80x64 TN S32+=S8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x80x64_S32S8U8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[40]; + + CUTE_HOST_DEVICE static void + 
fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %44, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n80k64.s32.s8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39}," + " %40," + " %41," + " %42, %43," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x80x64_S32S8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x80x64 TN S32+=S8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x80x64_S32S8U8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[40]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %44, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n80k64.s32.s8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, 
%22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39}," + " %40," + " %41," + " %42, %43," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x80x64_S32S8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x96x64 TN S32+=S8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x96x64_S32S8U8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[48]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %52, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n96k64.s32.s8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47}," + " %48," + " %49," + " %50, %51," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x96x64_S32S8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + 
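+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+// The SS-variant wrappers take both A and B as 64-bit shared-memory matrix descriptors
+// (desc_a/desc_b, "l" constraints) instead of A-fragment registers; the metadata word e, the
+// sparsity selector spsel, and the scale-D predicate are handled the same way as in the RS
+// forms. As a rough usage sketch (assuming desc_a, desc_b, and the metadata word e have
+// already been built elsewhere, e.g. by the surrounding CuTe machinery), the smallest SS
+// shape could be invoked directly like this:
+//
+//   uint32_t acc[4] = {0, 0, 0, 0};
+//   SM90::GMMA::SPARSE::GMMA_64x8x64_S32S8U8_SS_TN<>::fma(
+//       desc_a, desc_b,
+//       acc[0], acc[1], acc[2], acc[3],
+//       e, GMMA::ScaleOut::One);
+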
+//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x96x64 TN S32+=S8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x96x64_S32S8U8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[48]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %52, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n96k64.s32.s8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47}," + " %48," + " %49," + " %50, %51," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x96x64_S32S8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x112x64 TN S32+=S8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x112x64_S32S8U8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[56]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, 
uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %60, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n112k64.s32.s8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55}," + " %56," + " %57," + " %58, %59," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x112x64_S32S8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x112x64 TN S32+=S8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x112x64_S32S8U8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[56]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + 
uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %60, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n112k64.s32.s8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55}," + " %56," + " %57," + " %58, %59," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x112x64_S32S8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x128x64 TN S32+=S8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x128x64_S32S8U8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[64]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %68, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n128k64.s32.s8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, 
%26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63}," + " %64," + " %65," + " %66, %67," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x128x64_S32S8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x128x64 TN S32+=S8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x128x64_S32S8U8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[64]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %68, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n128k64.s32.s8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63}," + " %64," + " %65," + " %66, %67," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), 
"+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x128x64_S32S8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x144x64 TN S32+=S8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x144x64_S32S8U8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[72]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %76, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n144k64.s32.s8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71}," + " %72," + " %73," + " %74, %75," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), 
"+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x144x64_S32S8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x144x64 TN S32+=S8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x144x64_S32S8U8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[72]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %76, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n144k64.s32.s8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71}," + " %72," + " %73," + " %74, %75," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), 
"+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x144x64_S32S8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x160x64 TN S32+=S8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x160x64_S32S8U8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[80]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, + uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %84, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n160k64.s32.s8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79}," + " %80," + " %81," + " %82, %83," + " p;\n" + 
"}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), + "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), + "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x160x64_S32S8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x160x64 TN S32+=S8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x160x64_S32S8U8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[80]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, + uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %84, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n160k64.s32.s8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " 
%40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79}," + " %80," + " %81," + " %82, %83," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), + "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), + "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x160x64_S32S8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x176x64 TN S32+=S8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x176x64_S32S8U8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[88]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, + uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, + uint32_t & d80, uint32_t & d81, uint32_t & d82, uint32_t & d83, + uint32_t & d84, uint32_t & d85, uint32_t & d86, uint32_t & d87, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if 
defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %92, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n176k64.s32.s8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87}," + " %88," + " %89," + " %90, %91," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), + "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), + "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79), + "+r"(d80), "+r"(d81), "+r"(d82), "+r"(d83), + "+r"(d84), "+r"(d85), "+r"(d86), "+r"(d87) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x176x64_S32S8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x176x64 TN S32+=S8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x176x64_S32S8U8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[88]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & 
d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, + uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, + uint32_t & d80, uint32_t & d81, uint32_t & d82, uint32_t & d83, + uint32_t & d84, uint32_t & d85, uint32_t & d86, uint32_t & d87, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %92, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n176k64.s32.s8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87}," + " %88," + " %89," + " %90, %91," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), + "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), + "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79), + "+r"(d80), "+r"(d81), "+r"(d82), "+r"(d83), + "+r"(d84), "+r"(d85), "+r"(d86), "+r"(d87) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x176x64_S32S8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x192x64 TN S32+=S8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x192x64_S32S8U8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[96]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, 
uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, + uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, + uint32_t & d80, uint32_t & d81, uint32_t & d82, uint32_t & d83, + uint32_t & d84, uint32_t & d85, uint32_t & d86, uint32_t & d87, + uint32_t & d88, uint32_t & d89, uint32_t & d90, uint32_t & d91, + uint32_t & d92, uint32_t & d93, uint32_t & d94, uint32_t & d95, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %100, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n192k64.s32.s8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95}," + " %96," + " %97," + " %98, %99," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), + "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), + "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79), + "+r"(d80), "+r"(d81), "+r"(d82), "+r"(d83), + "+r"(d84), "+r"(d85), "+r"(d86), "+r"(d87), + "+r"(d88), "+r"(d89), "+r"(d90), "+r"(d91), + "+r"(d92), "+r"(d93), "+r"(d94), "+r"(d95) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x192x64_S32S8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x192x64 TN S32+=S8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x192x64_S32S8U8_SS_TN_SATURATE +{ + using DRegisters = 
void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[96]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, + uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, + uint32_t & d80, uint32_t & d81, uint32_t & d82, uint32_t & d83, + uint32_t & d84, uint32_t & d85, uint32_t & d86, uint32_t & d87, + uint32_t & d88, uint32_t & d89, uint32_t & d90, uint32_t & d91, + uint32_t & d92, uint32_t & d93, uint32_t & d94, uint32_t & d95, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %100, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n192k64.s32.s8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95}," + " %96," + " %97," + " %98, %99," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), + "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), + 
"+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79), + "+r"(d80), "+r"(d81), "+r"(d82), "+r"(d83), + "+r"(d84), "+r"(d85), "+r"(d86), "+r"(d87), + "+r"(d88), "+r"(d89), "+r"(d90), "+r"(d91), + "+r"(d92), "+r"(d93), "+r"(d94), "+r"(d95) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x192x64_S32S8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x208x64 TN S32+=S8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x208x64_S32S8U8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[104]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %108, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n208k64.s32.s8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, 
%79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103}," + " %104," + " %105," + " %106, %107," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x208x64_S32S8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x208x64 TN S32+=S8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x208x64_S32S8U8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[104]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, 
uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %108, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n208k64.s32.s8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103}," + " %104," + " %105," + " %106, %107," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x208x64_S32S8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x224x64 TN S32+=S8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x224x64_S32S8U8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[112]; + + 
CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, + uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %116, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n224k64.s32.s8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111}," + " %112," + " %113," + " %114, %115," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + 
"+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), + "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), + "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x224x64_S32S8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x224x64 TN S32+=S8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x224x64_S32S8U8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[112]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & 
d107, + uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %116, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n224k64.s32.s8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111}," + " %112," + " %113," + " %114, %115," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), + "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), + "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x224x64_S32S8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x240x64 TN S32+=S8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x240x64_S32S8U8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[120]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & 
d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, + uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, + uint32_t & d112, uint32_t & d113, uint32_t & d114, uint32_t & d115, + uint32_t & d116, uint32_t & d117, uint32_t & d118, uint32_t & d119, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %124, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n240k64.s32.s8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119}," + " %120," + " %121," + " %122, %123," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), 
"+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), + "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), + "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111), + "+r"(d112), "+r"(d113), "+r"(d114), "+r"(d115), + "+r"(d116), "+r"(d117), "+r"(d118), "+r"(d119) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x240x64_S32S8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x240x64 TN S32+=S8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x240x64_S32S8U8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[120]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, + uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, + 
uint32_t & d112, uint32_t & d113, uint32_t & d114, uint32_t & d115, + uint32_t & d116, uint32_t & d117, uint32_t & d118, uint32_t & d119, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %124, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n240k64.s32.s8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119}," + " %120," + " %121," + " %122, %123," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), + "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), + "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111), + "+r"(d112), "+r"(d113), "+r"(d114), "+r"(d115), + "+r"(d116), "+r"(d117), "+r"(d118), "+r"(d119) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x240x64_S32S8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x256x64 TN S32+=S8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x256x64_S32S8U8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[128]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, 
uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, + uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, + uint32_t & d112, uint32_t & d113, uint32_t & d114, uint32_t & d115, + uint32_t & d116, uint32_t & d117, uint32_t & d118, uint32_t & d119, + uint32_t & d120, uint32_t & d121, uint32_t & d122, uint32_t & d123, + uint32_t & d124, uint32_t & d125, uint32_t & d126, uint32_t & d127, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %132, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n256k64.s32.s8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119, " + " %120, %121, %122, %123, %124, %125, %126, %127}," + " %128," + " %129," + " %130, %131," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), 
"+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), + "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), + "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111), + "+r"(d112), "+r"(d113), "+r"(d114), "+r"(d115), + "+r"(d116), "+r"(d117), "+r"(d118), "+r"(d119), + "+r"(d120), "+r"(d121), "+r"(d122), "+r"(d123), + "+r"(d124), "+r"(d125), "+r"(d126), "+r"(d127) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x256x64_S32S8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x256x64 TN S32+=S8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x256x64_S32S8U8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[128]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, 
uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, + uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, + uint32_t & d112, uint32_t & d113, uint32_t & d114, uint32_t & d115, + uint32_t & d116, uint32_t & d117, uint32_t & d118, uint32_t & d119, + uint32_t & d120, uint32_t & d121, uint32_t & d122, uint32_t & d123, + uint32_t & d124, uint32_t & d125, uint32_t & d126, uint32_t & d127, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %132, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n256k64.s32.s8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119, " + " %120, %121, %122, %123, %124, %125, %126, %127}," + " %128," + " %129," + " %130, %131," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), + "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), + "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111), + "+r"(d112), "+r"(d113), "+r"(d114), "+r"(d115), + "+r"(d116), "+r"(d117), "+r"(d118), "+r"(d119), + "+r"(d120), "+r"(d121), "+r"(d122), "+r"(d123), + "+r"(d124), "+r"(d125), "+r"(d126), "+r"(d127) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use 
SM90::GMMA::SPARSE::GMMA_64x256x64_S32S8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x8x64 TN S32+=S8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x8x64_S32S8U8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[4]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %11, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n8k64.s32.s8.u8 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + " %8," + " %9, %10," + " p;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x8x64_S32S8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x8x64 TN S32+=S8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x8x64_S32S8U8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[4]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %11, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n8k64.s32.s8.u8.satfinite " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + " %8," + " %9, %10," + " p;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x8x64_S32S8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x16x64 TN S32+=S8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x16x64_S32S8U8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[8]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t & d4, uint32_t & d5, uint32_t & d6, uint32_t & d7, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %15, 0;\n" + 
"wgmma.mma_async.sp.sync.aligned.m64n16k64.s32.s8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7}," + "{%8, %9, %10, %11}," + " %12," + " %13, %14," + " p;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3), + "+r"(d4), "+r"(d5), "+r"(d6), "+r"(d7) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x16x64_S32S8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x16x64 TN S32+=S8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x16x64_S32S8U8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[8]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t & d4, uint32_t & d5, uint32_t & d6, uint32_t & d7, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %15, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n16k64.s32.s8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7}," + "{%8, %9, %10, %11}," + " %12," + " %13, %14," + " p;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3), + "+r"(d4), "+r"(d5), "+r"(d6), "+r"(d7) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x16x64_S32S8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x32x64 TN S32+=S8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x32x64_S32S8U8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[16]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %23, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n32k64.s32.s8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15}," + "{%16, %17, %18, %19}," + " %20," + " %21, %22," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use 
SM90::GMMA::SPARSE::GMMA_64x32x64_S32S8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x32x64 TN S32+=S8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x32x64_S32S8U8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[16]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %23, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n32k64.s32.s8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15}," + "{%16, %17, %18, %19}," + " %20," + " %21, %22," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x32x64_S32S8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x48x64 TN S32+=S8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x48x64_S32S8U8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[24]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %31, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n48k64.s32.s8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23}," + "{%24, %25, %26, %27}," + " %28," + " %29, %30," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23) + : 
"r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x48x64_S32S8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x48x64 TN S32+=S8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x48x64_S32S8U8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[24]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %31, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n48k64.s32.s8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23}," + "{%24, %25, %26, %27}," + " %28," + " %29, %30," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x48x64_S32S8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x64x64 TN S32+=S8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x64x64_S32S8U8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[32]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + 
".reg .pred p;\n" + "setp.ne.b32 p, %39, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n64k64.s32.s8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31}," + "{%32, %33, %34, %35}," + " %36," + " %37, %38," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x64x64_S32S8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x64x64 TN S32+=S8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x64x64_S32S8U8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[32]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %39, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n64k64.s32.s8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31}," + "{%32, %33, %34, %35}," + " %36," + " %37, %38," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x64x64_S32S8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x80x64 TN S32+=S8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x80x64_S32S8U8_RS_TN +{ + using 
DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[40]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %47, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n80k64.s32.s8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39}," + "{%40, %41, %42, %43}," + " %44," + " %45, %46," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x80x64_S32S8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x80x64 TN S32+=S8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x80x64_S32S8U8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[40]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t const& e, + 
GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %47, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n80k64.s32.s8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39}," + "{%40, %41, %42, %43}," + " %44," + " %45, %46," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x80x64_S32S8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x96x64 TN S32+=S8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x96x64_S32S8U8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[48]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %55, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n96k64.s32.s8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47}," + "{%48, %49, %50, %51}," + " %52," + " %53, %54," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), 
"+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x96x64_S32S8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x96x64 TN S32+=S8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x96x64_S32S8U8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[48]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %55, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n96k64.s32.s8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47}," + "{%48, %49, %50, %51}," + " %52," + " %53, %54," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x96x64_S32S8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x112x64 TN S32+=S8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct 
GMMA_64x112x64_S32S8U8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[56]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %63, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n112k64.s32.s8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55}," + "{%56, %57, %58, %59}," + " %60," + " %61, %62," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x112x64_S32S8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x112x64 TN S32+=S8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x112x64_S32S8U8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[56]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, 
uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %63, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n112k64.s32.s8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55}," + "{%56, %57, %58, %59}," + " %60," + " %61, %62," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x112x64_S32S8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x128x64 TN S32+=S8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x128x64_S32S8U8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[64]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & 
d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %71, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n128k64.s32.s8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63}," + "{%64, %65, %66, %67}," + " %68," + " %69, %70," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x128x64_S32S8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x128x64 TN S32+=S8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x128x64_S32S8U8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[64]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, 
uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %71, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n128k64.s32.s8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63}," + "{%64, %65, %66, %67}," + " %68," + " %69, %70," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x128x64_S32S8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x144x64 TN S32+=S8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x144x64_S32S8U8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[72]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t 
& d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %79, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n144k64.s32.s8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71}," + "{%72, %73, %74, %75}," + " %76," + " %77, %78," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x144x64_S32S8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x144x64 TN S32+=S8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x144x64_S32S8U8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[72]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, 
uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %79, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n144k64.s32.s8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71}," + "{%72, %73, %74, %75}," + " %76," + " %77, %78," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x144x64_S32S8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x160x64 TN S32+=S8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x160x64_S32S8U8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[80]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & 
d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, + uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %87, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n160k64.s32.s8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79}," + "{%80, %81, %82, %83}," + " %84," + " %85, %86," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), + "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), + "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x160x64_S32S8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x160x64 TN S32+=S8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x160x64_S32S8U8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[80]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & 
d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, + uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %87, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n160k64.s32.s8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79}," + "{%80, %81, %82, %83}," + " %84," + " %85, %86," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), + "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), + "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x160x64_S32S8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x176x64 TN S32+=S8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x176x64_S32S8U8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[88]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + 
uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, + uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, + uint32_t & d80, uint32_t & d81, uint32_t & d82, uint32_t & d83, + uint32_t & d84, uint32_t & d85, uint32_t & d86, uint32_t & d87, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %95, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n176k64.s32.s8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87}," + "{%88, %89, %90, %91}," + " %92," + " %93, %94," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), + "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), + "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79), + "+r"(d80), "+r"(d81), "+r"(d82), "+r"(d83), + "+r"(d84), "+r"(d85), "+r"(d86), "+r"(d87) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x176x64_S32S8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + 
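+// Register-count note for the 64xNx64 sparse s8 RS atoms above (derived from the
+// declarations themselves, not an additional interface): the warpgroup's 64xN s32
+// accumulator tile is distributed over 128 threads, so CRegisters = uint32_t[N/2]
+// per thread (e.g. uint32_t[88] for N = 176). ARegisters = uint32_t[4] holds each
+// thread's compressed (2:4 structured-sparse) s8 A fragment, and
+// ERegisters = uint32_t[1] holds the matching sparsity metadata passed as `e`.
+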
+//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x176x64 TN S32+=S8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x176x64_S32S8U8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[88]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, + uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, + uint32_t & d80, uint32_t & d81, uint32_t & d82, uint32_t & d83, + uint32_t & d84, uint32_t & d85, uint32_t & d86, uint32_t & d87, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %95, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n176k64.s32.s8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87}," + "{%88, %89, %90, %91}," + " %92," + " %93, %94," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + 
"+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), + "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), + "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79), + "+r"(d80), "+r"(d81), "+r"(d82), "+r"(d83), + "+r"(d84), "+r"(d85), "+r"(d86), "+r"(d87) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x176x64_S32S8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x192x64 TN S32+=S8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x192x64_S32S8U8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[96]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, + uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, + uint32_t & d80, uint32_t & d81, uint32_t & d82, uint32_t & d83, + uint32_t & d84, uint32_t & d85, uint32_t & d86, uint32_t & d87, + uint32_t & d88, uint32_t & d89, uint32_t & d90, uint32_t & d91, + uint32_t & d92, uint32_t & d93, uint32_t & d94, uint32_t & d95, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %103, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n192k64.s32.s8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " 
%88, %89, %90, %91, %92, %93, %94, %95}," + "{%96, %97, %98, %99}," + " %100," + " %101, %102," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), + "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), + "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79), + "+r"(d80), "+r"(d81), "+r"(d82), "+r"(d83), + "+r"(d84), "+r"(d85), "+r"(d86), "+r"(d87), + "+r"(d88), "+r"(d89), "+r"(d90), "+r"(d91), + "+r"(d92), "+r"(d93), "+r"(d94), "+r"(d95) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x192x64_S32S8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x192x64 TN S32+=S8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x192x64_S32S8U8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[96]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, + uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, + uint32_t & d80, uint32_t & d81, uint32_t & d82, uint32_t & d83, + uint32_t & d84, uint32_t & d85, uint32_t & d86, uint32_t & d87, + uint32_t & d88, uint32_t & d89, 
uint32_t & d90, uint32_t & d91, + uint32_t & d92, uint32_t & d93, uint32_t & d94, uint32_t & d95, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %103, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n192k64.s32.s8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95}," + "{%96, %97, %98, %99}," + " %100," + " %101, %102," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), + "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), + "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79), + "+r"(d80), "+r"(d81), "+r"(d82), "+r"(d83), + "+r"(d84), "+r"(d85), "+r"(d86), "+r"(d87), + "+r"(d88), "+r"(d89), "+r"(d90), "+r"(d91), + "+r"(d92), "+r"(d93), "+r"(d94), "+r"(d95) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x192x64_S32S8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x208x64 TN S32+=S8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x208x64_S32S8U8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[104]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & 
d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %111, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n208k64.s32.s8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103}," + "{%104, %105, %106, %107}," + " %108," + " %109, %110," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(e), 
"n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x208x64_S32S8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x208x64 TN S32+=S8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x208x64_S32S8U8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[104]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %111, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n208k64.s32.s8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103}," + "{%104, %105, %106, %107}," + " 
%108," + " %109, %110," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x208x64_S32S8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x224x64 TN S32+=S8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x224x64_S32S8U8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[112]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, 
uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, + uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %119, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n224k64.s32.s8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111}," + "{%112, %113, %114, %115}," + " %116," + " %117, %118," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), + "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), + "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x224x64_S32S8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x224x64 TN S32+=S8*U8 +template < + GMMA::SparseSel 
spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x224x64_S32S8U8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[112]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, + uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %119, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n224k64.s32.s8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111}," + "{%112, %113, %114, %115}," + " %116," + " %117, %118," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + 
"+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), + "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), + "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x224x64_S32S8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x240x64 TN S32+=S8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x240x64_S32S8U8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[120]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, 
uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, + uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, + uint32_t & d112, uint32_t & d113, uint32_t & d114, uint32_t & d115, + uint32_t & d116, uint32_t & d117, uint32_t & d118, uint32_t & d119, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %127, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n240k64.s32.s8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119}," + "{%120, %121, %122, %123}," + " %124," + " %125, %126," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), + "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), + "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111), + "+r"(d112), "+r"(d113), "+r"(d114), "+r"(d115), + "+r"(d116), "+r"(d117), "+r"(d118), "+r"(d119) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x240x64_S32S8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + 
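+// Operand layout shared by the inline asm in these atoms: the first N/2 operands are
+// the accumulator registers, followed by the four A-fragment registers, the B matrix
+// descriptor, the sparsity metadata `e`, and the immediate sparse selector. The
+// leading setp turns the runtime scale_D value into predicate p, so
+// GMMA::ScaleOut::Zero overwrites the accumulators while GMMA::ScaleOut::One
+// accumulates into them.
+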
+//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x240x64 TN S32+=S8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x240x64_S32S8U8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[120]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, + uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, + uint32_t & d112, uint32_t & d113, uint32_t & d114, uint32_t & d115, + uint32_t & d116, uint32_t & d117, uint32_t & d118, uint32_t & d119, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %127, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n240k64.s32.s8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, 
%100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119}," + "{%120, %121, %122, %123}," + " %124," + " %125, %126," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), + "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), + "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111), + "+r"(d112), "+r"(d113), "+r"(d114), "+r"(d115), + "+r"(d116), "+r"(d117), "+r"(d118), "+r"(d119) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x240x64_S32S8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x256x64 TN S32+=S8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x256x64_S32S8U8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[128]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + 
uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, + uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, + uint32_t & d112, uint32_t & d113, uint32_t & d114, uint32_t & d115, + uint32_t & d116, uint32_t & d117, uint32_t & d118, uint32_t & d119, + uint32_t & d120, uint32_t & d121, uint32_t & d122, uint32_t & d123, + uint32_t & d124, uint32_t & d125, uint32_t & d126, uint32_t & d127, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %135, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n256k64.s32.s8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119, " + " %120, %121, %122, %123, %124, %125, %126, %127}," + "{%128, %129, %130, %131}," + " %132," + " %133, %134," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + 
"+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), + "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), + "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111), + "+r"(d112), "+r"(d113), "+r"(d114), "+r"(d115), + "+r"(d116), "+r"(d117), "+r"(d118), "+r"(d119), + "+r"(d120), "+r"(d121), "+r"(d122), "+r"(d123), + "+r"(d124), "+r"(d125), "+r"(d126), "+r"(d127) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x256x64_S32S8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x256x64 TN S32+=S8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x256x64_S32S8U8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[128]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, + uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, + uint32_t & d112, uint32_t & d113, uint32_t & d114, uint32_t & d115, + uint32_t & d116, uint32_t & d117, uint32_t & d118, uint32_t & d119, + uint32_t & d120, uint32_t & d121, uint32_t & d122, uint32_t & d123, + uint32_t & d124, uint32_t & d125, uint32_t & d126, uint32_t & d127, + uint32_t const& e, + GMMA::ScaleOut const scale_D = 
GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %135, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n256k64.s32.s8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119, " + " %120, %121, %122, %123, %124, %125, %126, %127}," + "{%128, %129, %130, %131}," + " %132," + " %133, %134," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), + "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), + "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111), + "+r"(d112), "+r"(d113), "+r"(d114), "+r"(d115), + "+r"(d116), "+r"(d117), "+r"(d118), "+r"(d119), + "+r"(d120), "+r"(d121), "+r"(d122), "+r"(d123), + "+r"(d124), "+r"(d125), "+r"(d126), "+r"(d127) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x256x64_S32S8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x8x64 TN S32+=U8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x8x64_S32U8S8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[4]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t const& e, + GMMA::ScaleOut const scale_D = 
GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %8, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n8k64.s32.u8.s8 " + "{%0, %1, %2, %3}," + " %4," + " %5," + " %6, %7," + " p;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x8x64_S32U8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x8x64 TN S32+=U8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x8x64_S32U8S8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[4]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %8, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n8k64.s32.u8.s8.satfinite " + "{%0, %1, %2, %3}," + " %4," + " %5," + " %6, %7," + " p;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x8x64_S32U8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x16x64 TN S32+=U8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x16x64_S32U8S8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[8]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t & d4, uint32_t & d5, uint32_t & d6, uint32_t & d7, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %12, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n16k64.s32.u8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7}," + " %8," + " %9," + " %10, %11," + " p;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3), + "+r"(d4), "+r"(d5), "+r"(d6), "+r"(d7) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x16x64_S32U8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x16x64 TN S32+=U8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x16x64_S32U8S8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[8]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + 
uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t & d4, uint32_t & d5, uint32_t & d6, uint32_t & d7, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %12, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n16k64.s32.u8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7}," + " %8," + " %9," + " %10, %11," + " p;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3), + "+r"(d4), "+r"(d5), "+r"(d6), "+r"(d7) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x16x64_S32U8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x32x64 TN S32+=U8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x32x64_S32U8S8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[16]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %20, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n32k64.s32.u8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15}," + " %16," + " %17," + " %18, %19," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x32x64_S32U8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x32x64 TN S32+=U8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x32x64_S32U8S8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[16]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %20, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n32k64.s32.u8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15}," + " %16," + " %17," + " %18, %19," + " p;\n" 
+ "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x32x64_S32U8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x48x64 TN S32+=U8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x48x64_S32U8S8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[24]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %28, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n48k64.s32.u8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23}," + " %24," + " %25," + " %26, %27," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x48x64_S32U8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x48x64 TN S32+=U8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x48x64_S32U8S8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[24]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %28, 0;\n" + 
"wgmma.mma_async.sp.sync.aligned.m64n48k64.s32.u8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23}," + " %24," + " %25," + " %26, %27," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x48x64_S32U8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x64x64 TN S32+=U8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x64x64_S32U8S8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[32]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %36, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n64k64.s32.u8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31}," + " %32," + " %33," + " %34, %35," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x64x64_S32U8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x64x64 TN S32+=U8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x64x64_S32U8S8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[32]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + 
uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %36, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n64k64.s32.u8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31}," + " %32," + " %33," + " %34, %35," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x64x64_S32U8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x80x64 TN S32+=U8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x80x64_S32U8S8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[40]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %44, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n80k64.s32.u8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39}," + " %40," + " %41," + " %42, %43," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), 
"+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x80x64_S32U8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x80x64 TN S32+=U8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x80x64_S32U8S8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[40]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %44, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n80k64.s32.u8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39}," + " %40," + " %41," + " %42, %43," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x80x64_S32U8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x96x64 TN S32+=U8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x96x64_S32U8S8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[48]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, 
uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %52, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n96k64.s32.u8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47}," + " %48," + " %49," + " %50, %51," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x96x64_S32U8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x96x64 TN S32+=U8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x96x64_S32U8S8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[48]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %52, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n96k64.s32.u8.s8.satfinite 
" + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47}," + " %48," + " %49," + " %50, %51," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x96x64_S32U8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x112x64 TN S32+=U8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x112x64_S32U8S8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[56]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %60, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n112k64.s32.u8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55}," + " %56," + " %57," + " %58, %59," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + 
"+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x112x64_S32U8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x112x64 TN S32+=U8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x112x64_S32U8S8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[56]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %60, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n112k64.s32.u8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55}," + " %56," + " %57," + " %58, %59," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x112x64_S32U8S8_SS_TN_SATURATE without 
CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x128x64 TN S32+=U8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x128x64_S32U8S8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[64]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %68, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n128k64.s32.u8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63}," + " %64," + " %65," + " %66, %67," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x128x64_S32U8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x128x64 TN S32+=U8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x128x64_S32U8S8_SS_TN_SATURATE +{ + using DRegisters = void; + using 
ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[64]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %68, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n128k64.s32.u8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63}," + " %64," + " %65," + " %66, %67," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x128x64_S32U8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x144x64 TN S32+=U8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x144x64_S32U8S8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[72]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & 
d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %76, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n144k64.s32.u8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71}," + " %72," + " %73," + " %74, %75," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x144x64_S32U8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x144x64 TN S32+=U8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x144x64_S32U8S8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[72]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + 
uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %76, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n144k64.s32.u8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71}," + " %72," + " %73," + " %74, %75," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x144x64_S32U8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x160x64 TN S32+=U8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x160x64_S32U8S8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[80]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + 
uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, + uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %84, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n160k64.s32.u8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79}," + " %80," + " %81," + " %82, %83," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), + "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), + "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x160x64_S32U8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x160x64 TN S32+=U8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x160x64_S32U8S8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; 
+ using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[80]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, + uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %84, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n160k64.s32.u8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79}," + " %80," + " %81," + " %82, %83," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), + "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), + "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x160x64_S32U8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + 
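For readers unfamiliar with how these generated wrappers are consumed, the sketch below is editorial and not part of the patch. It calls the 64x128x64 saturating variant shown earlier directly from device code: desc_a and desc_b are GMMA shared-memory matrix descriptors, e carries the 2:4 sparsity metadata word selected by the SparseSel template parameter, and the 64 uint32_t accumulators map to the {%0..%63} operands of the wgmma.mma_async.sp instruction. The kernel name, the header path, and the placeholder way the descriptors and metadata arrive as plain arguments are illustrative assumptions; a real kernel builds the descriptors from shared-memory layouts and issues the appropriate wgmma fence/commit/wait sequence around the call.

// Illustrative sketch only (names/paths assumed as noted above); requires sm_90a.
#include <cute/arch/mma_sm90_gmma_sparse.hpp>  // header defining the structs above (path assumed)

__device__ void sparse_gmma_64x128x64_example(uint64_t desc_a, uint64_t desc_b,
                                              uint32_t e, uint32_t (&d)[64])
{
  // Default SparseSel::Zero; accumulators d[] are updated in place (scale_D = One).
  using Op = cute::SM90::GMMA::SPARSE::GMMA_64x128x64_S32U8S8_SS_TN_SATURATE<>;
  Op::fma(desc_a, desc_b,
          d[ 0], d[ 1], d[ 2], d[ 3], d[ 4], d[ 5], d[ 6], d[ 7],
          d[ 8], d[ 9], d[10], d[11], d[12], d[13], d[14], d[15],
          d[16], d[17], d[18], d[19], d[20], d[21], d[22], d[23],
          d[24], d[25], d[26], d[27], d[28], d[29], d[30], d[31],
          d[32], d[33], d[34], d[35], d[36], d[37], d[38], d[39],
          d[40], d[41], d[42], d[43], d[44], d[45], d[46], d[47],
          d[48], d[49], d[50], d[51], d[52], d[53], d[54], d[55],
          d[56], d[57], d[58], d[59], d[60], d[61], d[62], d[63],
          e, cute::GMMA::ScaleOut::One);
}

The wider-N variants that follow (64x144x64 through 64x256x64, plain and .satfinite) differ only in the number of accumulator registers and in the extended-shape guard; they are invoked the same way.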
+//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x176x64 TN S32+=U8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x176x64_S32U8S8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[88]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, + uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, + uint32_t & d80, uint32_t & d81, uint32_t & d82, uint32_t & d83, + uint32_t & d84, uint32_t & d85, uint32_t & d86, uint32_t & d87, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %92, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n176k64.s32.u8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87}," + " %88," + " %89," + " %90, %91," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + 
"+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), + "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), + "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79), + "+r"(d80), "+r"(d81), "+r"(d82), "+r"(d83), + "+r"(d84), "+r"(d85), "+r"(d86), "+r"(d87) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x176x64_S32U8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x176x64 TN S32+=U8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x176x64_S32U8S8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[88]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, + uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, + uint32_t & d80, uint32_t & d81, uint32_t & d82, uint32_t & d83, + uint32_t & d84, uint32_t & d85, uint32_t & d86, uint32_t & d87, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %92, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n176k64.s32.u8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87}," + " %88," + " %89," + " %90, %91," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), 
"+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), + "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), + "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79), + "+r"(d80), "+r"(d81), "+r"(d82), "+r"(d83), + "+r"(d84), "+r"(d85), "+r"(d86), "+r"(d87) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x176x64_S32U8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x192x64 TN S32+=U8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x192x64_S32U8S8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[96]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, + uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, + uint32_t & d80, uint32_t & d81, uint32_t & d82, uint32_t & d83, + uint32_t & d84, uint32_t & d85, uint32_t & d86, uint32_t & d87, + uint32_t & d88, uint32_t & d89, uint32_t & d90, uint32_t & d91, + uint32_t & d92, uint32_t & d93, uint32_t & d94, uint32_t & d95, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %100, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n192k64.s32.u8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, 
%25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95}," + " %96," + " %97," + " %98, %99," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), + "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), + "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79), + "+r"(d80), "+r"(d81), "+r"(d82), "+r"(d83), + "+r"(d84), "+r"(d85), "+r"(d86), "+r"(d87), + "+r"(d88), "+r"(d89), "+r"(d90), "+r"(d91), + "+r"(d92), "+r"(d93), "+r"(d94), "+r"(d95) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x192x64_S32U8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x192x64 TN S32+=U8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x192x64_S32U8S8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[96]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t & d72, uint32_t & d73, uint32_t & 
d74, uint32_t & d75, + uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, + uint32_t & d80, uint32_t & d81, uint32_t & d82, uint32_t & d83, + uint32_t & d84, uint32_t & d85, uint32_t & d86, uint32_t & d87, + uint32_t & d88, uint32_t & d89, uint32_t & d90, uint32_t & d91, + uint32_t & d92, uint32_t & d93, uint32_t & d94, uint32_t & d95, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %100, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n192k64.s32.u8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95}," + " %96," + " %97," + " %98, %99," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), + "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), + "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79), + "+r"(d80), "+r"(d81), "+r"(d82), "+r"(d83), + "+r"(d84), "+r"(d85), "+r"(d86), "+r"(d87), + "+r"(d88), "+r"(d89), "+r"(d90), "+r"(d91), + "+r"(d92), "+r"(d93), "+r"(d94), "+r"(d95) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x192x64_S32U8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x208x64 TN S32+=U8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x208x64_S32U8S8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[104]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + 
uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %108, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n208k64.s32.u8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103}," + " %104," + " %105," + " %106, %107," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + 
"+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x208x64_S32U8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x208x64 TN S32+=U8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x208x64_S32U8S8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[104]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %108, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n208k64.s32.u8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103}," + " %104," + 
" %105," + " %106, %107," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x208x64_S32U8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x224x64 TN S32+=U8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x224x64_S32U8S8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[112]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, 
uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, + uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %116, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n224k64.s32.u8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111}," + " %112," + " %113," + " %114, %115," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), + "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), + "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x224x64_S32U8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x224x64 TN S32+=U8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x224x64_S32U8S8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = 
uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[112]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, + uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %116, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n224k64.s32.u8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111}," + " %112," + " %113," + " %114, %115," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), 
"+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), + "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), + "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x224x64_S32U8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x240x64 TN S32+=U8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x240x64_S32U8S8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[120]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, 
+ uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, + uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, + uint32_t & d112, uint32_t & d113, uint32_t & d114, uint32_t & d115, + uint32_t & d116, uint32_t & d117, uint32_t & d118, uint32_t & d119, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %124, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n240k64.s32.u8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119}," + " %120," + " %121," + " %122, %123," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), + "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), + "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111), + "+r"(d112), "+r"(d113), "+r"(d114), "+r"(d115), + "+r"(d116), "+r"(d117), "+r"(d118), "+r"(d119) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x240x64_S32U8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x240x64 TN S32+=U8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x240x64_S32U8S8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; 
+ using CRegisters = uint32_t[120]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, + uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, + uint32_t & d112, uint32_t & d113, uint32_t & d114, uint32_t & d115, + uint32_t & d116, uint32_t & d117, uint32_t & d118, uint32_t & d119, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %124, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n240k64.s32.u8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119}," + " %120," + " %121," + " %122, %123," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + 
"+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), + "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), + "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111), + "+r"(d112), "+r"(d113), "+r"(d114), "+r"(d115), + "+r"(d116), "+r"(d117), "+r"(d118), "+r"(d119) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x240x64_S32U8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x256x64 TN S32+=U8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x256x64_S32U8S8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[128]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & 
d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, + uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, + uint32_t & d112, uint32_t & d113, uint32_t & d114, uint32_t & d115, + uint32_t & d116, uint32_t & d117, uint32_t & d118, uint32_t & d119, + uint32_t & d120, uint32_t & d121, uint32_t & d122, uint32_t & d123, + uint32_t & d124, uint32_t & d125, uint32_t & d126, uint32_t & d127, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %132, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n256k64.s32.u8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119, " + " %120, %121, %122, %123, %124, %125, %126, %127}," + " %128," + " %129," + " %130, %131," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), + "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), + "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111), + "+r"(d112), "+r"(d113), "+r"(d114), "+r"(d115), + "+r"(d116), "+r"(d117), "+r"(d118), "+r"(d119), + "+r"(d120), "+r"(d121), "+r"(d122), "+r"(d123), + "+r"(d124), "+r"(d125), "+r"(d126), "+r"(d127) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x256x64_S32U8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); 
+#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x256x64 TN S32+=U8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x256x64_S32U8S8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[128]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, + uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, + uint32_t & d112, uint32_t & d113, uint32_t & d114, uint32_t & d115, + uint32_t & d116, uint32_t & d117, uint32_t & d118, uint32_t & d119, + uint32_t & d120, uint32_t & d121, uint32_t & d122, uint32_t & d123, + uint32_t & d124, uint32_t & d125, uint32_t & d126, uint32_t & d127, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %132, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n256k64.s32.u8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, 
%92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119, " + " %120, %121, %122, %123, %124, %125, %126, %127}," + " %128," + " %129," + " %130, %131," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), + "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), + "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111), + "+r"(d112), "+r"(d113), "+r"(d114), "+r"(d115), + "+r"(d116), "+r"(d117), "+r"(d118), "+r"(d119), + "+r"(d120), "+r"(d121), "+r"(d122), "+r"(d123), + "+r"(d124), "+r"(d125), "+r"(d126), "+r"(d127) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x256x64_S32U8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x8x64 TN S32+=U8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x8x64_S32U8S8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[4]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %11, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n8k64.s32.u8.s8 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + " %8," + " %9, %10," + " p;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x8x64_S32U8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x8x64 TN S32+=U8*S8 
+template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x8x64_S32U8S8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[4]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %11, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n8k64.s32.u8.s8.satfinite " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + " %8," + " %9, %10," + " p;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x8x64_S32U8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x16x64 TN S32+=U8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x16x64_S32U8S8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[8]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t & d4, uint32_t & d5, uint32_t & d6, uint32_t & d7, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %15, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n16k64.s32.u8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7}," + "{%8, %9, %10, %11}," + " %12," + " %13, %14," + " p;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3), + "+r"(d4), "+r"(d5), "+r"(d6), "+r"(d7) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x16x64_S32U8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x16x64 TN S32+=U8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x16x64_S32U8S8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[8]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t & d4, uint32_t & d5, uint32_t & d6, uint32_t & d7, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %15, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n16k64.s32.u8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7}," + "{%8, %9, %10, %11}," + " %12," 
+ " %13, %14," + " p;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3), + "+r"(d4), "+r"(d5), "+r"(d6), "+r"(d7) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x16x64_S32U8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x32x64 TN S32+=U8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x32x64_S32U8S8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[16]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %23, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n32k64.s32.u8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15}," + "{%16, %17, %18, %19}," + " %20," + " %21, %22," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x32x64_S32U8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x32x64 TN S32+=U8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x32x64_S32U8S8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[16]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %23, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n32k64.s32.u8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15}," + "{%16, %17, %18, %19}," + " %20," + " %21, %22," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), 
+ "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x32x64_S32U8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x48x64 TN S32+=U8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x48x64_S32U8S8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[24]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %31, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n48k64.s32.u8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23}," + "{%24, %25, %26, %27}," + " %28," + " %29, %30," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x48x64_S32U8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x48x64 TN S32+=U8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x48x64_S32U8S8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[24]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %31, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n48k64.s32.u8.s8.satfinite " + "{%0, %1, %2, %3, %4, 
%5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23}," + "{%24, %25, %26, %27}," + " %28," + " %29, %30," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x48x64_S32U8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x64x64 TN S32+=U8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x64x64_S32U8S8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[32]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %39, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n64k64.s32.u8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31}," + "{%32, %33, %34, %35}," + " %36," + " %37, %38," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x64x64_S32U8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x64x64 TN S32+=U8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x64x64_S32U8S8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[32]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, 
uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %39, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n64k64.s32.u8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31}," + "{%32, %33, %34, %35}," + " %36," + " %37, %38," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x64x64_S32U8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x80x64 TN S32+=U8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x80x64_S32U8S8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[40]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %47, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n80k64.s32.u8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39}," + "{%40, %41, %42, %43}," + " %44," + " %45, %46," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), 
"+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x80x64_S32U8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x80x64 TN S32+=U8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x80x64_S32U8S8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[40]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %47, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n80k64.s32.u8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39}," + "{%40, %41, %42, %43}," + " %44," + " %45, %46," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x80x64_S32U8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x96x64 TN S32+=U8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x96x64_S32U8S8_RS_TN +{ + using DRegisters = void; + using ARegisters = 
uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[48]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %55, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n96k64.s32.u8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47}," + "{%48, %49, %50, %51}," + " %52," + " %53, %54," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x96x64_S32U8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x96x64 TN S32+=U8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x96x64_S32U8S8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[48]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, 
uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %55, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n96k64.s32.u8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47}," + "{%48, %49, %50, %51}," + " %52," + " %53, %54," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x96x64_S32U8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x112x64 TN S32+=U8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x112x64_S32U8S8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[56]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %63, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n112k64.s32.u8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, 
%10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55}," + "{%56, %57, %58, %59}," + " %60," + " %61, %62," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x112x64_S32U8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x112x64 TN S32+=U8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x112x64_S32U8S8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[56]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %63, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n112k64.s32.u8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55}," + "{%56, %57, %58, %59}," + " %60," + " %61, %62," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), 
"+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x112x64_S32U8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x128x64 TN S32+=U8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x128x64_S32U8S8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[64]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %71, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n128k64.s32.u8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63}," + "{%64, %65, %66, %67}," + " %68," + " %69, %70," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), 
"+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x128x64_S32U8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x128x64 TN S32+=U8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x128x64_S32U8S8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[64]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %71, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n128k64.s32.u8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63}," + "{%64, %65, %66, %67}," + " %68," + " %69, %70," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + 
"+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x128x64_S32U8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x144x64 TN S32+=U8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x144x64_S32U8S8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[72]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %79, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n144k64.s32.u8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71}," + "{%72, %73, %74, %75}," + " %76," + " %77, %78," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + 
"+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x144x64_S32U8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x144x64 TN S32+=U8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x144x64_S32U8S8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[72]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %79, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n144k64.s32.u8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71}," + "{%72, %73, %74, %75}," + " %76," + " %77, %78," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), 
"+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x144x64_S32U8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x160x64 TN S32+=U8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x160x64_S32U8S8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[80]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, + uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %87, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n160k64.s32.u8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79}," + "{%80, %81, %82, %83}," + " %84," + " %85, %86," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), 
"+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), + "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), + "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x160x64_S32U8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x160x64 TN S32+=U8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x160x64_S32U8S8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[80]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, + uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %87, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n160k64.s32.u8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, 
%73, %74, %75, %76, %77, %78, %79}," + "{%80, %81, %82, %83}," + " %84," + " %85, %86," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), + "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), + "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x160x64_S32U8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x176x64 TN S32+=U8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x176x64_S32U8S8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[88]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, + uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, + uint32_t & d80, uint32_t & d81, uint32_t & d82, uint32_t & d83, + uint32_t & d84, uint32_t & d85, uint32_t & d86, uint32_t & d87, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + 
"setp.ne.b32 p, %95, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n176k64.s32.u8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87}," + "{%88, %89, %90, %91}," + " %92," + " %93, %94," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), + "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), + "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79), + "+r"(d80), "+r"(d81), "+r"(d82), "+r"(d83), + "+r"(d84), "+r"(d85), "+r"(d86), "+r"(d87) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x176x64_S32U8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x176x64 TN S32+=U8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x176x64_S32U8S8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[88]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & 
d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, + uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, + uint32_t & d80, uint32_t & d81, uint32_t & d82, uint32_t & d83, + uint32_t & d84, uint32_t & d85, uint32_t & d86, uint32_t & d87, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %95, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n176k64.s32.u8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87}," + "{%88, %89, %90, %91}," + " %92," + " %93, %94," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), + "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), + "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79), + "+r"(d80), "+r"(d81), "+r"(d82), "+r"(d83), + "+r"(d84), "+r"(d85), "+r"(d86), "+r"(d87) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x176x64_S32U8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x192x64 TN S32+=U8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x192x64_S32U8S8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[96]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & 
d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, + uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, + uint32_t & d80, uint32_t & d81, uint32_t & d82, uint32_t & d83, + uint32_t & d84, uint32_t & d85, uint32_t & d86, uint32_t & d87, + uint32_t & d88, uint32_t & d89, uint32_t & d90, uint32_t & d91, + uint32_t & d92, uint32_t & d93, uint32_t & d94, uint32_t & d95, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %103, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n192k64.s32.u8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95}," + "{%96, %97, %98, %99}," + " %100," + " %101, %102," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), + "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), + "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79), + "+r"(d80), "+r"(d81), "+r"(d82), "+r"(d83), + "+r"(d84), "+r"(d85), "+r"(d86), "+r"(d87), + "+r"(d88), "+r"(d89), "+r"(d90), "+r"(d91), + "+r"(d92), "+r"(d93), "+r"(d94), "+r"(d95) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x192x64_S32U8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// 
SPARSE GMMA 64x192x64 TN S32+=U8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x192x64_S32U8S8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[96]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, + uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, + uint32_t & d80, uint32_t & d81, uint32_t & d82, uint32_t & d83, + uint32_t & d84, uint32_t & d85, uint32_t & d86, uint32_t & d87, + uint32_t & d88, uint32_t & d89, uint32_t & d90, uint32_t & d91, + uint32_t & d92, uint32_t & d93, uint32_t & d94, uint32_t & d95, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %103, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n192k64.s32.u8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95}," + "{%96, %97, %98, %99}," + " %100," + " %101, %102," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), 
"+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), + "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), + "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79), + "+r"(d80), "+r"(d81), "+r"(d82), "+r"(d83), + "+r"(d84), "+r"(d85), "+r"(d86), "+r"(d87), + "+r"(d88), "+r"(d89), "+r"(d90), "+r"(d91), + "+r"(d92), "+r"(d93), "+r"(d94), "+r"(d95) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x192x64_S32U8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x208x64 TN S32+=U8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x208x64_S32U8S8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[104]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %111, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n208k64.s32.u8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, 
%18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103}," + "{%104, %105, %106, %107}," + " %108," + " %109, %110," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x208x64_S32U8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x208x64 TN S32+=U8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x208x64_S32U8S8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[104]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & 
d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %111, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n208k64.s32.u8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103}," + "{%104, %105, %106, %107}," + " %108," + " %109, %110," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x208x64_S32U8S8_RS_TN_SATURATE without 
CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x224x64 TN S32+=U8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x224x64_S32U8S8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[112]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, + uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %119, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n224k64.s32.u8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111}," + "{%112, %113, %114, 
%115}," + " %116," + " %117, %118," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), + "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), + "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x224x64_S32U8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x224x64 TN S32+=U8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x224x64_S32U8S8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[112]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & 
d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, + uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %119, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n224k64.s32.u8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111}," + "{%112, %113, %114, %115}," + " %116," + " %117, %118," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), + "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), + "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x224x64_S32U8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + 
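For orientation, the following is a minimal sketch of how one of the smaller sparse GMMA atoms declared in this file (GMMA_64x8x64_S32U8U8_SS_TN, defined further below) might be invoked directly. The header path, namespace qualification, and the way the shared-memory descriptors and 2:4 sparsity metadata are obtained are assumptions for illustration only; in practice these atoms are consumed through CuTe's MMA_Atom/TiledMMA machinery, which also issues the wgmma fence/commit/wait sequencing that this sketch deliberately omits.

// Hedged sketch, not part of this patch. Assumed header and namespace below.
#include <cute/arch/mma_sm90_gmma_sparse.hpp>   // assumed location of these atoms

__device__ void sparse_gmma_sketch(uint64_t desc_a,    // SMEM descriptor for sparse A (assumed prebuilt)
                                   uint64_t desc_b,    // SMEM descriptor for B (assumed prebuilt)
                                   uint32_t meta,      // packed 2:4 sparsity metadata for A
                                   uint32_t (&acc)[4]) // S32 accumulators, one warpgroup's tile
{
  // SparseSel defaults to Zero; namespace path is an assumption based on the
  // diagnostic strings in this file ("SM90::GMMA::SPARSE::...").
  using Atom = cute::SM90::GMMA::SPARSE::GMMA_64x8x64_S32U8U8_SS_TN<>;

  // One warpgroup-wide sparse MMA: acc (S32) += A (u8, 2:4 sparse) * B (u8).
  // Must be executed by all 128 threads of the warpgroup; the required
  // wgmma.fence / commit_group / wait_group bracketing is omitted here.
  Atom::fma(desc_a, desc_b,
            acc[0], acc[1], acc[2], acc[3],
            meta,
            cute::GMMA::ScaleOut::One);
}

The larger shapes in this file follow the same pattern, differing only in the number of accumulator registers and, for the RS variants, in taking A as four uint32_t register operands instead of a descriptor.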
+//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x240x64 TN S32+=U8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x240x64_S32U8S8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[120]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, + uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, + uint32_t & d112, uint32_t & d113, uint32_t & d114, uint32_t & d115, + uint32_t & d116, uint32_t & d117, uint32_t & d118, uint32_t & d119, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %127, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n240k64.s32.u8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, 
%103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119}," + "{%120, %121, %122, %123}," + " %124," + " %125, %126," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), + "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), + "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111), + "+r"(d112), "+r"(d113), "+r"(d114), "+r"(d115), + "+r"(d116), "+r"(d117), "+r"(d118), "+r"(d119) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x240x64_S32U8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x240x64 TN S32+=U8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x240x64_S32U8S8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[120]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, 
uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, + uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, + uint32_t & d112, uint32_t & d113, uint32_t & d114, uint32_t & d115, + uint32_t & d116, uint32_t & d117, uint32_t & d118, uint32_t & d119, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %127, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n240k64.s32.u8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119}," + "{%120, %121, %122, %123}," + " %124," + " %125, %126," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), + "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), + 
"+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111), + "+r"(d112), "+r"(d113), "+r"(d114), "+r"(d115), + "+r"(d116), "+r"(d117), "+r"(d118), "+r"(d119) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x240x64_S32U8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x256x64 TN S32+=U8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x256x64_S32U8S8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[128]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, + uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, + uint32_t & d112, uint32_t & d113, uint32_t & d114, uint32_t & d115, + uint32_t & d116, uint32_t & d117, uint32_t & d118, uint32_t & d119, + uint32_t & d120, uint32_t & d121, uint32_t & d122, uint32_t & d123, + uint32_t & d124, uint32_t & d125, uint32_t & d126, uint32_t & d127, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %135, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n256k64.s32.u8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, 
%10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119, " + " %120, %121, %122, %123, %124, %125, %126, %127}," + "{%128, %129, %130, %131}," + " %132," + " %133, %134," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), + "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), + "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111), + "+r"(d112), "+r"(d113), "+r"(d114), "+r"(d115), + "+r"(d116), "+r"(d117), "+r"(d118), "+r"(d119), + "+r"(d120), "+r"(d121), "+r"(d122), "+r"(d123), + "+r"(d124), "+r"(d125), "+r"(d126), "+r"(d127) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x256x64_S32U8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x256x64 TN S32+=U8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x256x64_S32U8S8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[128]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, 
uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, + uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, + uint32_t & d112, uint32_t & d113, uint32_t & d114, uint32_t & d115, + uint32_t & d116, uint32_t & d117, uint32_t & d118, uint32_t & d119, + uint32_t & d120, uint32_t & d121, uint32_t & d122, uint32_t & d123, + uint32_t & d124, uint32_t & d125, uint32_t & d126, uint32_t & d127, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %135, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n256k64.s32.u8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119, " + " %120, %121, %122, %123, %124, %125, %126, %127}," + "{%128, %129, %130, %131}," + " %132," + " %133, %134," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), 
"+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), + "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), + "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111), + "+r"(d112), "+r"(d113), "+r"(d114), "+r"(d115), + "+r"(d116), "+r"(d117), "+r"(d118), "+r"(d119), + "+r"(d120), "+r"(d121), "+r"(d122), "+r"(d123), + "+r"(d124), "+r"(d125), "+r"(d126), "+r"(d127) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x256x64_S32U8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x8x64 TN S32+=U8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x8x64_S32U8U8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[4]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %8, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n8k64.s32.u8.u8 " + "{%0, %1, %2, %3}," + " %4," + " %5," + " %6, %7," + " p;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x8x64_S32U8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x8x64 TN S32+=U8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x8x64_S32U8U8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[4]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %8, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n8k64.s32.u8.u8.satfinite " + "{%0, %1, %2, %3}," + " %4," + " %5," + " %6, %7," + " p;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + 
"r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x8x64_S32U8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x16x64 TN S32+=U8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x16x64_S32U8U8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[8]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t & d4, uint32_t & d5, uint32_t & d6, uint32_t & d7, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %12, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n16k64.s32.u8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7}," + " %8," + " %9," + " %10, %11," + " p;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3), + "+r"(d4), "+r"(d5), "+r"(d6), "+r"(d7) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x16x64_S32U8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x16x64 TN S32+=U8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x16x64_S32U8U8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[8]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t & d4, uint32_t & d5, uint32_t & d6, uint32_t & d7, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %12, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n16k64.s32.u8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7}," + " %8," + " %9," + " %10, %11," + " p;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3), + "+r"(d4), "+r"(d5), "+r"(d6), "+r"(d7) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x16x64_S32U8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x32x64 TN S32+=U8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x32x64_S32U8U8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[16]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, 
uint32_t & d15, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %20, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n32k64.s32.u8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15}," + " %16," + " %17," + " %18, %19," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x32x64_S32U8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x32x64 TN S32+=U8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x32x64_S32U8U8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[16]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %20, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n32k64.s32.u8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15}," + " %16," + " %17," + " %18, %19," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x32x64_S32U8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x48x64 TN S32+=U8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x48x64_S32U8U8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[24]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" 
+ "setp.ne.b32 p, %28, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n48k64.s32.u8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23}," + " %24," + " %25," + " %26, %27," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x48x64_S32U8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x48x64 TN S32+=U8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x48x64_S32U8U8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[24]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %28, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n48k64.s32.u8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23}," + " %24," + " %25," + " %26, %27," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x48x64_S32U8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x64x64 TN S32+=U8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x64x64_S32U8U8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[32]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, 
uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %36, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n64k64.s32.u8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31}," + " %32," + " %33," + " %34, %35," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x64x64_S32U8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x64x64 TN S32+=U8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x64x64_S32U8U8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[32]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %36, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n64k64.s32.u8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31}," + " %32," + " %33," + " %34, %35," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x64x64_S32U8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + 
+//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x80x64 TN S32+=U8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x80x64_S32U8U8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[40]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %44, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n80k64.s32.u8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39}," + " %40," + " %41," + " %42, %43," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x80x64_S32U8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x80x64 TN S32+=U8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x80x64_S32U8U8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[40]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & 
d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %44, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n80k64.s32.u8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39}," + " %40," + " %41," + " %42, %43," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x80x64_S32U8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x96x64 TN S32+=U8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x96x64_S32U8U8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[48]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %52, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n96k64.s32.u8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47}," + " %48," + " %49," + " %50, %51," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), 
"+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x96x64_S32U8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x96x64 TN S32+=U8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x96x64_S32U8U8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[48]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %52, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n96k64.s32.u8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47}," + " %48," + " %49," + " %50, %51," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x96x64_S32U8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x112x64 TN S32+=U8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x112x64_S32U8U8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using 
ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[56]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %60, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n112k64.s32.u8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55}," + " %56," + " %57," + " %58, %59," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x112x64_S32U8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x112x64 TN S32+=U8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x112x64_S32U8U8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[56]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, 
uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %60, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n112k64.s32.u8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55}," + " %56," + " %57," + " %58, %59," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x112x64_S32U8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x128x64 TN S32+=U8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x128x64_S32U8U8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[64]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & 
d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %68, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n128k64.s32.u8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63}," + " %64," + " %65," + " %66, %67," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x128x64_S32U8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x128x64 TN S32+=U8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x128x64_S32U8U8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[64]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %68, 0;\n" + 
"wgmma.mma_async.sp.sync.aligned.m64n128k64.s32.u8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63}," + " %64," + " %65," + " %66, %67," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x128x64_S32U8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x144x64 TN S32+=U8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x144x64_S32U8U8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[72]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %76, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n144k64.s32.u8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, 
%17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71}," + " %72," + " %73," + " %74, %75," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x144x64_S32U8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x144x64 TN S32+=U8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x144x64_S32U8U8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[72]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %76, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n144k64.s32.u8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, 
%17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71}," + " %72," + " %73," + " %74, %75," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x144x64_S32U8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x160x64 TN S32+=U8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x160x64_S32U8U8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[80]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, + uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %84, 0;\n" + 
"wgmma.mma_async.sp.sync.aligned.m64n160k64.s32.u8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79}," + " %80," + " %81," + " %82, %83," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), + "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), + "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x160x64_S32U8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x160x64 TN S32+=U8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x160x64_S32U8U8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[80]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, + 
uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %84, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n160k64.s32.u8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79}," + " %80," + " %81," + " %82, %83," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), + "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), + "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x160x64_S32U8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x176x64 TN S32+=U8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x176x64_S32U8U8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[88]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & 
d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, + uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, + uint32_t & d80, uint32_t & d81, uint32_t & d82, uint32_t & d83, + uint32_t & d84, uint32_t & d85, uint32_t & d86, uint32_t & d87, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %92, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n176k64.s32.u8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87}," + " %88," + " %89," + " %90, %91," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), + "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), + "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79), + "+r"(d80), "+r"(d81), "+r"(d82), "+r"(d83), + "+r"(d84), "+r"(d85), "+r"(d86), "+r"(d87) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x176x64_S32U8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x176x64 TN S32+=U8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x176x64_S32U8U8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[88]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t 
& d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, + uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, + uint32_t & d80, uint32_t & d81, uint32_t & d82, uint32_t & d83, + uint32_t & d84, uint32_t & d85, uint32_t & d86, uint32_t & d87, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %92, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n176k64.s32.u8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87}," + " %88," + " %89," + " %90, %91," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), + "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), + "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79), + "+r"(d80), "+r"(d81), "+r"(d82), "+r"(d83), + "+r"(d84), "+r"(d85), "+r"(d86), "+r"(d87) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x176x64_S32U8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x192x64 TN S32+=U8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x192x64_S32U8U8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[96]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& 
desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, + uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, + uint32_t & d80, uint32_t & d81, uint32_t & d82, uint32_t & d83, + uint32_t & d84, uint32_t & d85, uint32_t & d86, uint32_t & d87, + uint32_t & d88, uint32_t & d89, uint32_t & d90, uint32_t & d91, + uint32_t & d92, uint32_t & d93, uint32_t & d94, uint32_t & d95, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %100, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n192k64.s32.u8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95}," + " %96," + " %97," + " %98, %99," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), + "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), + "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79), + "+r"(d80), "+r"(d81), "+r"(d82), "+r"(d83), + "+r"(d84), "+r"(d85), "+r"(d86), "+r"(d87), + "+r"(d88), "+r"(d89), "+r"(d90), "+r"(d91), + "+r"(d92), "+r"(d93), 
"+r"(d94), "+r"(d95) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x192x64_S32U8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x192x64 TN S32+=U8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x192x64_S32U8U8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[96]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, + uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, + uint32_t & d80, uint32_t & d81, uint32_t & d82, uint32_t & d83, + uint32_t & d84, uint32_t & d85, uint32_t & d86, uint32_t & d87, + uint32_t & d88, uint32_t & d89, uint32_t & d90, uint32_t & d91, + uint32_t & d92, uint32_t & d93, uint32_t & d94, uint32_t & d95, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %100, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n192k64.s32.u8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95}," + " %96," + " %97," + " %98, %99," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + 
"+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), + "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), + "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79), + "+r"(d80), "+r"(d81), "+r"(d82), "+r"(d83), + "+r"(d84), "+r"(d85), "+r"(d86), "+r"(d87), + "+r"(d88), "+r"(d89), "+r"(d90), "+r"(d91), + "+r"(d92), "+r"(d93), "+r"(d94), "+r"(d95) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x192x64_S32U8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x208x64 TN S32+=U8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x208x64_S32U8U8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[104]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" 
+ ".reg .pred p;\n" + "setp.ne.b32 p, %108, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n208k64.s32.u8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103}," + " %104," + " %105," + " %106, %107," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x208x64_S32U8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x208x64 TN S32+=U8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x208x64_S32U8U8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[104]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & 
d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %108, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n208k64.s32.u8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103}," + " %104," + " %105," + " %106, %107," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use 
SM90::GMMA::SPARSE::GMMA_64x208x64_S32U8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x224x64 TN S32+=U8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x224x64_S32U8U8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[112]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, + uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %116, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n224k64.s32.u8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111}," + " %112," + " 
%113," + " %114, %115," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), + "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), + "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x224x64_S32U8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x224x64 TN S32+=U8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x224x64_S32U8U8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[112]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & 
d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, + uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %116, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n224k64.s32.u8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111}," + " %112," + " %113," + " %114, %115," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), + "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), + "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x224x64_S32U8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x240x64 TN S32+=U8*U8 +template < + GMMA::SparseSel spsel = 
GMMA::SparseSel::Zero +> +struct GMMA_64x240x64_S32U8U8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[120]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, + uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, + uint32_t & d112, uint32_t & d113, uint32_t & d114, uint32_t & d115, + uint32_t & d116, uint32_t & d117, uint32_t & d118, uint32_t & d119, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %124, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n240k64.s32.u8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119}," + " %120," + " %121," + " %122, %123," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), 
"+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), + "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), + "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111), + "+r"(d112), "+r"(d113), "+r"(d114), "+r"(d115), + "+r"(d116), "+r"(d117), "+r"(d118), "+r"(d119) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x240x64_S32U8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x240x64 TN S32+=U8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x240x64_S32U8U8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[120]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, 
uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, + uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, + uint32_t & d112, uint32_t & d113, uint32_t & d114, uint32_t & d115, + uint32_t & d116, uint32_t & d117, uint32_t & d118, uint32_t & d119, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %124, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n240k64.s32.u8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119}," + " %120," + " %121," + " %122, %123," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), + "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), + "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111), + "+r"(d112), "+r"(d113), "+r"(d114), "+r"(d115), + "+r"(d116), "+r"(d117), "+r"(d118), "+r"(d119) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x240x64_S32U8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + 
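As a reader aid (not part of the patch itself): the accumulator register counts declared by these sparse GMMA atoms follow directly from the tile shape. Each 64xNx64 S32 atom above declares CRegisters = uint32_t[N/2], because the 64xN block of 32-bit accumulators is distributed over the 128 threads of one warpgroup, one register per element (64*N/128 = N/2); the extra `uint32_t const& e` operand appears to carry the 2:4 sparsity metadata selected by the `spsel` template parameter. Below is a minimal host-side sketch of that register-count rule; the helper name is illustrative only and does not exist in CUTLASS.

// Editorial sketch, not part of the patch: checks that the CRegisters
// array length of a 64xNx64 sparse GMMA S32 atom equals N/2.
#include <cstdint>

constexpr std::uint32_t sparse_gmma_s32_acc_regs(std::uint32_t n) {
  // 64 rows x n columns of s32 accumulators, spread over a 128-thread
  // warpgroup, one 32-bit register per accumulator element.
  return 64u * n / 128u;
}

static_assert(sparse_gmma_s32_acc_regs(160) ==  80, "GMMA_64x160x64_* uses uint32_t[80]");
static_assert(sparse_gmma_s32_acc_regs(192) ==  96, "GMMA_64x192x64_* uses uint32_t[96]");
static_assert(sparse_gmma_s32_acc_regs(240) == 120, "GMMA_64x240x64_* uses uint32_t[120]");
static_assert(sparse_gmma_s32_acc_regs(256) == 128, "GMMA_64x256x64_* uses uint32_t[128]");

The same rule accounts for the predicate operand index in each asm block: with N/2 accumulator outputs, the descriptors, metadata, and sparse selector occupy the next four operand slots, so scale_D lands at operand N/2 + 4 (for example, %124 for the 64x240x64 atoms above).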
+//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x256x64 TN S32+=U8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x256x64_S32U8U8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[128]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, + uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, + uint32_t & d112, uint32_t & d113, uint32_t & d114, uint32_t & d115, + uint32_t & d116, uint32_t & d117, uint32_t & d118, uint32_t & d119, + uint32_t & d120, uint32_t & d121, uint32_t & d122, uint32_t & d123, + uint32_t & d124, uint32_t & d125, uint32_t & d126, uint32_t & d127, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %132, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n256k64.s32.u8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, 
%99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119, " + " %120, %121, %122, %123, %124, %125, %126, %127}," + " %128," + " %129," + " %130, %131," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), + "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), + "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111), + "+r"(d112), "+r"(d113), "+r"(d114), "+r"(d115), + "+r"(d116), "+r"(d117), "+r"(d118), "+r"(d119), + "+r"(d120), "+r"(d121), "+r"(d122), "+r"(d123), + "+r"(d124), "+r"(d125), "+r"(d126), "+r"(d127) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x256x64_S32U8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x256x64 TN S32+=U8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x256x64_S32U8U8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[128]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & 
d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, + uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, + uint32_t & d112, uint32_t & d113, uint32_t & d114, uint32_t & d115, + uint32_t & d116, uint32_t & d117, uint32_t & d118, uint32_t & d119, + uint32_t & d120, uint32_t & d121, uint32_t & d122, uint32_t & d123, + uint32_t & d124, uint32_t & d125, uint32_t & d126, uint32_t & d127, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %132, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n256k64.s32.u8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119, " + " %120, %121, %122, %123, %124, %125, %126, %127}," + " %128," + " %129," + " %130, %131," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), 
"+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), + "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), + "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111), + "+r"(d112), "+r"(d113), "+r"(d114), "+r"(d115), + "+r"(d116), "+r"(d117), "+r"(d118), "+r"(d119), + "+r"(d120), "+r"(d121), "+r"(d122), "+r"(d123), + "+r"(d124), "+r"(d125), "+r"(d126), "+r"(d127) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x256x64_S32U8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x8x64 TN S32+=U8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x8x64_S32U8U8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[4]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %11, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n8k64.s32.u8.u8 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + " %8," + " %9, %10," + " p;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x8x64_S32U8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x8x64 TN S32+=U8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x8x64_S32U8U8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[4]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %11, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n8k64.s32.u8.u8.satfinite " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + " %8," + " %9, %10," + " p;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x8x64_S32U8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x16x64 TN S32+=U8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x16x64_S32U8U8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using 
ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[8]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t & d4, uint32_t & d5, uint32_t & d6, uint32_t & d7, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %15, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n16k64.s32.u8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7}," + "{%8, %9, %10, %11}," + " %12," + " %13, %14," + " p;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3), + "+r"(d4), "+r"(d5), "+r"(d6), "+r"(d7) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x16x64_S32U8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x16x64 TN S32+=U8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x16x64_S32U8U8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[8]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t & d4, uint32_t & d5, uint32_t & d6, uint32_t & d7, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %15, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n16k64.s32.u8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7}," + "{%8, %9, %10, %11}," + " %12," + " %13, %14," + " p;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3), + "+r"(d4), "+r"(d5), "+r"(d6), "+r"(d7) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x16x64_S32U8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x32x64 TN S32+=U8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x32x64_S32U8U8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[16]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %23, 0;\n" + 
"wgmma.mma_async.sp.sync.aligned.m64n32k64.s32.u8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15}," + "{%16, %17, %18, %19}," + " %20," + " %21, %22," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x32x64_S32U8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x32x64 TN S32+=U8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x32x64_S32U8U8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[16]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %23, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n32k64.s32.u8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15}," + "{%16, %17, %18, %19}," + " %20," + " %21, %22," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x32x64_S32U8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x48x64 TN S32+=U8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x48x64_S32U8U8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[24]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + 
"setp.ne.b32 p, %31, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n48k64.s32.u8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23}," + "{%24, %25, %26, %27}," + " %28," + " %29, %30," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x48x64_S32U8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x48x64 TN S32+=U8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x48x64_S32U8U8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[24]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %31, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n48k64.s32.u8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23}," + "{%24, %25, %26, %27}," + " %28," + " %29, %30," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x48x64_S32U8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x64x64 TN S32+=U8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x64x64_S32U8U8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[32]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & 
d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %39, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n64k64.s32.u8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31}," + "{%32, %33, %34, %35}," + " %36," + " %37, %38," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x64x64_S32U8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x64x64 TN S32+=U8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x64x64_S32U8U8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[32]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %39, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n64k64.s32.u8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31}," + "{%32, %33, %34, %35}," + " %36," + " %37, %38," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31) + : "r"(a00), "r"(a01), 
"r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x64x64_S32U8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x80x64 TN S32+=U8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x80x64_S32U8U8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[40]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %47, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n80k64.s32.u8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39}," + "{%40, %41, %42, %43}," + " %44," + " %45, %46," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x80x64_S32U8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x80x64 TN S32+=U8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x80x64_S32U8U8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[40]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, 
uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %47, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n80k64.s32.u8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39}," + "{%40, %41, %42, %43}," + " %44," + " %45, %46," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x80x64_S32U8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x96x64 TN S32+=U8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x96x64_S32U8U8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[48]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %55, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n96k64.s32.u8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, 
%20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47}," + "{%48, %49, %50, %51}," + " %52," + " %53, %54," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x96x64_S32U8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x96x64 TN S32+=U8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x96x64_S32U8U8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[48]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %55, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n96k64.s32.u8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47}," + "{%48, %49, %50, %51}," + " %52," + " %53, %54," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), 
"+r"(d47) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x96x64_S32U8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x112x64 TN S32+=U8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x112x64_S32U8U8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[56]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %63, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n112k64.s32.u8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55}," + "{%56, %57, %58, %59}," + " %60," + " %61, %62," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x112x64_S32U8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// 
SPARSE GMMA 64x112x64 TN S32+=U8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x112x64_S32U8U8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[56]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %63, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n112k64.s32.u8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55}," + "{%56, %57, %58, %59}," + " %60," + " %61, %62," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x112x64_S32U8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x128x64 TN S32+=U8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x128x64_S32U8U8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[64]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & 
d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %71, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n128k64.s32.u8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63}," + "{%64, %65, %66, %67}," + " %68," + " %69, %70," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x128x64_S32U8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x128x64 TN S32+=U8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x128x64_S32U8U8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[64]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + 
uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %71, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n128k64.s32.u8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63}," + "{%64, %65, %66, %67}," + " %68," + " %69, %70," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x128x64_S32U8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x144x64 TN S32+=U8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x144x64_S32U8U8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[72]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & 
d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %79, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n144k64.s32.u8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71}," + "{%72, %73, %74, %75}," + " %76," + " %77, %78," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x144x64_S32U8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x144x64 TN S32+=U8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x144x64_S32U8U8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[72]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, 
uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %79, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n144k64.s32.u8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71}," + "{%72, %73, %74, %75}," + " %76," + " %77, %78," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x144x64_S32U8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x160x64 TN S32+=U8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x160x64_S32U8U8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[80]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & 
d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, + uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %87, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n160k64.s32.u8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79}," + "{%80, %81, %82, %83}," + " %84," + " %85, %86," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), + "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), + "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x160x64_S32U8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x160x64 TN S32+=U8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x160x64_S32U8U8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[80]; + + CUTE_HOST_DEVICE static 
void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, + uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %87, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n160k64.s32.u8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79}," + "{%80, %81, %82, %83}," + " %84," + " %85, %86," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), + "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), + "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x160x64_S32U8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if 
defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x176x64 TN S32+=U8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x176x64_S32U8U8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[88]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, + uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, + uint32_t & d80, uint32_t & d81, uint32_t & d82, uint32_t & d83, + uint32_t & d84, uint32_t & d85, uint32_t & d86, uint32_t & d87, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %95, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n176k64.s32.u8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87}," + "{%88, %89, %90, %91}," + " %92," + " %93, %94," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), 
"+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), + "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), + "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79), + "+r"(d80), "+r"(d81), "+r"(d82), "+r"(d83), + "+r"(d84), "+r"(d85), "+r"(d86), "+r"(d87) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x176x64_S32U8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x176x64 TN S32+=U8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x176x64_S32U8U8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[88]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, + uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, + uint32_t & d80, uint32_t & d81, uint32_t & d82, uint32_t & d83, + uint32_t & d84, uint32_t & d85, uint32_t & d86, uint32_t & d87, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %95, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n176k64.s32.u8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87}," + "{%88, %89, %90, %91}," + " %92," + " %93, %94," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), 
"+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), + "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), + "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79), + "+r"(d80), "+r"(d81), "+r"(d82), "+r"(d83), + "+r"(d84), "+r"(d85), "+r"(d86), "+r"(d87) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x176x64_S32U8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x192x64 TN S32+=U8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x192x64_S32U8U8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[96]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, + uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, + uint32_t & d80, uint32_t & d81, uint32_t & d82, uint32_t & d83, + uint32_t & d84, uint32_t & d85, uint32_t & d86, uint32_t & d87, + uint32_t & d88, uint32_t & d89, uint32_t & d90, uint32_t & d91, + uint32_t & d92, uint32_t & d93, uint32_t & d94, uint32_t & d95, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %103, 0;\n" + 
"wgmma.mma_async.sp.sync.aligned.m64n192k64.s32.u8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95}," + "{%96, %97, %98, %99}," + " %100," + " %101, %102," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), + "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), + "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79), + "+r"(d80), "+r"(d81), "+r"(d82), "+r"(d83), + "+r"(d84), "+r"(d85), "+r"(d86), "+r"(d87), + "+r"(d88), "+r"(d89), "+r"(d90), "+r"(d91), + "+r"(d92), "+r"(d93), "+r"(d94), "+r"(d95) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x192x64_S32U8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x192x64 TN S32+=U8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x192x64_S32U8U8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[96]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t 
& d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, + uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, + uint32_t & d80, uint32_t & d81, uint32_t & d82, uint32_t & d83, + uint32_t & d84, uint32_t & d85, uint32_t & d86, uint32_t & d87, + uint32_t & d88, uint32_t & d89, uint32_t & d90, uint32_t & d91, + uint32_t & d92, uint32_t & d93, uint32_t & d94, uint32_t & d95, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %103, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n192k64.s32.u8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95}," + "{%96, %97, %98, %99}," + " %100," + " %101, %102," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), + "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), + "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79), + "+r"(d80), "+r"(d81), "+r"(d82), "+r"(d83), + "+r"(d84), "+r"(d85), "+r"(d86), "+r"(d87), + "+r"(d88), "+r"(d89), "+r"(d90), "+r"(d91), + "+r"(d92), "+r"(d93), "+r"(d94), "+r"(d95) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x192x64_S32U8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x208x64 TN S32+=U8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x208x64_S32U8U8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[104]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + uint32_t & d000, 
uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %111, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n208k64.s32.u8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103}," + "{%104, %105, %106, %107}," + " %108," + " %109, %110," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), 
"+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x208x64_S32U8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x208x64 TN S32+=U8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x208x64_S32U8U8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[104]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %111, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n208k64.s32.u8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + 
" %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103}," + "{%104, %105, %106, %107}," + " %108," + " %109, %110," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x208x64_S32U8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x224x64 TN S32+=U8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x224x64_S32U8U8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[112]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, 
uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, + uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %119, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n224k64.s32.u8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111}," + "{%112, %113, %114, %115}," + " %116," + " %117, %118," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), + "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), + "+r"(d108), 
"+r"(d109), "+r"(d110), "+r"(d111) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x224x64_S32U8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x224x64 TN S32+=U8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x224x64_S32U8U8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[112]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, + uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %119, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n224k64.s32.u8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, 
%68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111}," + "{%112, %113, %114, %115}," + " %116," + " %117, %118," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), + "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), + "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x224x64_S32U8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x240x64 TN S32+=U8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x240x64_S32U8U8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[120]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, 
uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, + uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, + uint32_t & d112, uint32_t & d113, uint32_t & d114, uint32_t & d115, + uint32_t & d116, uint32_t & d117, uint32_t & d118, uint32_t & d119, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %127, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n240k64.s32.u8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119}," + "{%120, %121, %122, %123}," + " %124," + " %125, %126," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), + "+r"(d104), 
"+r"(d105), "+r"(d106), "+r"(d107), + "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111), + "+r"(d112), "+r"(d113), "+r"(d114), "+r"(d115), + "+r"(d116), "+r"(d117), "+r"(d118), "+r"(d119) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x240x64_S32U8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x240x64 TN S32+=U8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x240x64_S32U8U8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[120]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, + uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, + uint32_t & d112, uint32_t & d113, uint32_t & d114, uint32_t & d115, + uint32_t & d116, uint32_t & d117, uint32_t & d118, uint32_t & d119, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %127, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n240k64.s32.u8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, 
%17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119}," + "{%120, %121, %122, %123}," + " %124," + " %125, %126," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), + "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), + "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111), + "+r"(d112), "+r"(d113), "+r"(d114), "+r"(d115), + "+r"(d116), "+r"(d117), "+r"(d118), "+r"(d119) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x240x64_S32U8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x256x64 TN S32+=U8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x256x64_S32U8U8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[128]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & 
d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, + uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, + uint32_t & d112, uint32_t & d113, uint32_t & d114, uint32_t & d115, + uint32_t & d116, uint32_t & d117, uint32_t & d118, uint32_t & d119, + uint32_t & d120, uint32_t & d121, uint32_t & d122, uint32_t & d123, + uint32_t & d124, uint32_t & d125, uint32_t & d126, uint32_t & d127, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %135, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n256k64.s32.u8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119, " + " %120, %121, %122, %123, %124, %125, %126, %127}," + "{%128, %129, %130, %131}," + " %132," + " %133, %134," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), 
"+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), + "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), + "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111), + "+r"(d112), "+r"(d113), "+r"(d114), "+r"(d115), + "+r"(d116), "+r"(d117), "+r"(d118), "+r"(d119), + "+r"(d120), "+r"(d121), "+r"(d122), "+r"(d123), + "+r"(d124), "+r"(d125), "+r"(d126), "+r"(d127) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x256x64_S32U8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x256x64 TN S32+=U8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x256x64_S32U8U8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[128]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + 
uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, + uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, + uint32_t & d112, uint32_t & d113, uint32_t & d114, uint32_t & d115, + uint32_t & d116, uint32_t & d117, uint32_t & d118, uint32_t & d119, + uint32_t & d120, uint32_t & d121, uint32_t & d122, uint32_t & d123, + uint32_t & d124, uint32_t & d125, uint32_t & d126, uint32_t & d127, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %135, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n256k64.s32.u8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119, " + " %120, %121, %122, %123, %124, %125, %126, %127}," + "{%128, %129, %130, %131}," + " %132," + " %133, %134," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), + "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), + "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111), + "+r"(d112), "+r"(d113), "+r"(d114), "+r"(d115), + "+r"(d116), "+r"(d117), "+r"(d118), "+r"(d119), + "+r"(d120), "+r"(d121), "+r"(d122), "+r"(d123), + "+r"(d124), "+r"(d125), "+r"(d126), "+r"(d127) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x256x64_S32U8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x8x64 TN F16+=E4M3*E4M3 
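+//
+// The FP8 sparse atoms below follow the same conventions as the integer atoms
+// above: _SS_ variants take both A and B as 64-bit shared-memory matrix
+// descriptors, while _RS_ variants take A in registers (uint32_t[4]) and B as a
+// descriptor. `e` carries the 32-bit sparsity metadata, `spsel` selects which
+// threads' metadata is used, `scale_D` (predicate p) chooses whether D is
+// overwritten or accumulated into, and scaleA/scaleB are the immediate input
+// scale factors.
+//
+// Minimal illustrative sketch (desc_a, desc_b, and e are placeholders for
+// values produced elsewhere; the call must sit inside the usual
+// wgmma fence / commit_group / wait_group sequence of a warpgroup):
+//
+//   uint32_t d0 = 0, d1 = 0;  // two packed f16x2 accumulators per thread
+//   GMMA_64x8x64_F16E4M3E4M3_SS_TN<>::fma(desc_a, desc_b, d0, d1, e,
+//                                         GMMA::ScaleOut::Zero);
+//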
+template <
+  GMMA::ScaleIn scaleA = GMMA::ScaleIn::One,
+  GMMA::ScaleIn scaleB = GMMA::ScaleIn::One,
+  GMMA::SparseSel spsel = GMMA::SparseSel::Zero
+>
+struct GMMA_64x8x64_F16E4M3E4M3_SS_TN
+{
+  using DRegisters = void;
+  using ARegisters = uint64_t[1];
+  using ERegisters = uint32_t[1];
+  using BRegisters = uint64_t[1];
+  using CRegisters = uint32_t[2];
+
+  CUTE_HOST_DEVICE static void
+  fma(uint64_t const& desc_a,
+      uint64_t const& desc_b,
+      uint32_t & d0, uint32_t & d1,
+      uint32_t const& e,
+      GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
+  {
+#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
+    asm volatile(
+    "{\n"
+      ".reg .pred p;\n"
+      "setp.ne.b32 p, %6, 0;\n"
+      "wgmma.mma_async.sp.sync.aligned.m64n8k64.f16.e4m3.e4m3 "
+      "{%0, %1},"
+      " %2,"
+      " %3,"
+      " %4, %5,"
+      " p, %7, %8;\n"
+    "}\n"
+      : "+r"(d0), "+r"(d1)
+      : "l"(desc_a),
+        "l"(desc_b),
+        "r"(e), "n"(int32_t(spsel)),
+        "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)));
+#else
+    CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x8x64_F16E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED");
+#endif
+  }
+};
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+// SPARSE GMMA 64x8x64 TN F16+=E4M3*E4M3
+template <
+  GMMA::ScaleIn scaleA = GMMA::ScaleIn::One,
+  GMMA::ScaleIn scaleB = GMMA::ScaleIn::One,
+  GMMA::SparseSel spsel = GMMA::SparseSel::Zero
+>
+struct GMMA_64x8x64_F16E4M3E4M3_RS_TN
+{
+  using DRegisters = void;
+  using ARegisters = uint32_t[4];
+  using ERegisters = uint32_t[1];
+  using BRegisters = uint64_t[1];
+  using CRegisters = uint32_t[2];
+
+  CUTE_HOST_DEVICE static void
+  fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3,
+      uint64_t const& desc_b,
+      uint32_t & d0, uint32_t & d1,
+      uint32_t const& e,
+      GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
+  {
+#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
+    asm volatile(
+    "{\n"
+      ".reg .pred p;\n"
+      "setp.ne.b32 p, %9, 0;\n"
+      "wgmma.mma_async.sp.sync.aligned.m64n8k64.f16.e4m3.e4m3 "
+      "{%0, %1},"
+      "{%2, %3, %4, %5},"
+      " %6,"
+      " %7, %8,"
+      " p, %10, %11;\n"
+    "}\n"
+      : "+r"(d0), "+r"(d1)
+      : "r"(a0), "r"(a1), "r"(a2), "r"(a3),
+        "l"(desc_b),
+        "r"(e), "n"(int32_t(spsel)),
+        "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)));
+#else
+    CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x8x64_F16E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED");
+#endif
+  }
+};
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+// SPARSE GMMA 64x8x64 TN F32+=E4M3*E4M3
+template <
+  GMMA::ScaleIn scaleA = GMMA::ScaleIn::One,
+  GMMA::ScaleIn scaleB = GMMA::ScaleIn::One,
+  GMMA::SparseSel spsel = GMMA::SparseSel::Zero
+>
+struct GMMA_64x8x64_F32E4M3E4M3_SS_TN
+{
+  using DRegisters = void;
+  using ARegisters = uint64_t[1];
+  using ERegisters = uint32_t[1];
+  using BRegisters = uint64_t[1];
+  using CRegisters = float[4];
+
+  CUTE_HOST_DEVICE static void
+  fma(uint64_t const& desc_a,
+      uint64_t const& desc_b,
+      float & d0, float & d1, float & d2, float & d3,
+      uint32_t const& e,
+      GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
+  {
+#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
+    asm volatile(
+    "{\n"
+      ".reg .pred p;\n"
+      "setp.ne.b32 p, %8, 0;\n"
+      "wgmma.mma_async.sp.sync.aligned.m64n8k64.f32.e4m3.e4m3 "
+      "{%0, %1, %2, %3},"
+      " %4,"
+      " %5,"
+      " %6, %7,"
+      " p, %9, %10;\n"
+    "}\n"
+      : "+f"(d0), "+f"(d1), "+f"(d2), "+f"(d3)
+      : "l"(desc_a),
+        "l"(desc_b),
+        "r"(e), "n"(int32_t(spsel)),
+        "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)));
+#else
+    CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x8x64_F32E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED");
+#endif
+  }
+};
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+// SPARSE GMMA 64x8x64 TN F32+=E4M3*E4M3
+template <
+  GMMA::ScaleIn scaleA = GMMA::ScaleIn::One,
+  GMMA::ScaleIn scaleB = GMMA::ScaleIn::One,
+  GMMA::SparseSel spsel = GMMA::SparseSel::Zero
+>
+struct GMMA_64x8x64_F32E4M3E4M3_RS_TN
+{
+  using DRegisters = void;
+  using ARegisters = uint32_t[4];
+  using ERegisters = uint32_t[1];
+  using BRegisters = uint64_t[1];
+  using CRegisters = float[4];
+
+  CUTE_HOST_DEVICE static void
+  fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3,
+      uint64_t const& desc_b,
+      float & d0, float & d1, float & d2, float & d3,
+      uint32_t const& e,
+      GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
+  {
+#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
+    asm volatile(
+    "{\n"
+      ".reg .pred p;\n"
+      "setp.ne.b32 p, %11, 0;\n"
+      "wgmma.mma_async.sp.sync.aligned.m64n8k64.f32.e4m3.e4m3 "
+      "{%0, %1, %2, %3},"
+      "{%4, %5, %6, %7},"
+      " %8,"
+      " %9, %10,"
+      " p, %12, %13;\n"
+    "}\n"
+      : "+f"(d0), "+f"(d1), "+f"(d2), "+f"(d3)
+      : "r"(a0), "r"(a1), "r"(a2), "r"(a3),
+        "l"(desc_b),
+        "r"(e), "n"(int32_t(spsel)),
+        "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)));
+#else
+    CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x8x64_F32E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED");
+#endif
+  }
+};
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+// SPARSE GMMA 64x16x64 TN F16+=E4M3*E4M3
+template <
+  GMMA::ScaleIn scaleA = GMMA::ScaleIn::One,
+  GMMA::ScaleIn scaleB = GMMA::ScaleIn::One,
+  GMMA::SparseSel spsel = GMMA::SparseSel::Zero
+>
+struct GMMA_64x16x64_F16E4M3E4M3_SS_TN
+{
+  using DRegisters = void;
+  using ARegisters = uint64_t[1];
+  using ERegisters = uint32_t[1];
+  using BRegisters = uint64_t[1];
+  using CRegisters = uint32_t[4];
+
+  CUTE_HOST_DEVICE static void
+  fma(uint64_t const& desc_a,
+      uint64_t const& desc_b,
+      uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3,
+      uint32_t const& e,
+      GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
+  {
+#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
+    asm volatile(
+    "{\n"
+      ".reg .pred p;\n"
+      "setp.ne.b32 p, %8, 0;\n"
+      "wgmma.mma_async.sp.sync.aligned.m64n16k64.f16.e4m3.e4m3 "
+      "{%0, %1, %2, %3},"
+      " %4,"
+      " %5,"
+      " %6, %7,"
+      " p, %9, %10;\n"
+    "}\n"
+      : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3)
+      : "l"(desc_a),
+        "l"(desc_b),
+        "r"(e), "n"(int32_t(spsel)),
+        "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)));
+#else
+    CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x16x64_F16E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED");
+#endif
+  }
+};
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+// SPARSE GMMA 64x16x64 TN F16+=E4M3*E4M3
+template <
+  GMMA::ScaleIn scaleA = GMMA::ScaleIn::One,
+  GMMA::ScaleIn scaleB = GMMA::ScaleIn::One,
+  GMMA::SparseSel spsel = GMMA::SparseSel::Zero
+>
+struct GMMA_64x16x64_F16E4M3E4M3_RS_TN
+{
+  using DRegisters = void;
+  using ARegisters = uint32_t[4];
+  using ERegisters = uint32_t[1];
+  using BRegisters = uint64_t[1];
+  using CRegisters = uint32_t[4];
+
+  CUTE_HOST_DEVICE static void
+  fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3,
+      uint64_t const& desc_b,
+      uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3,
+      uint32_t const& e,
+      GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
+  {
+#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
+    asm volatile(
+    "{\n"
+      ".reg .pred p;\n"
+      "setp.ne.b32 p, %11, 0;\n"
+      "wgmma.mma_async.sp.sync.aligned.m64n16k64.f16.e4m3.e4m3 "
+      "{%0, %1, %2, %3},"
+      "{%4, %5, %6, %7},"
+      " %8,"
+      " %9, %10,"
+      " p, %12, %13;\n"
+    "}\n"
+      : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3)
+      : "r"(a0), "r"(a1), "r"(a2), "r"(a3),
+        "l"(desc_b),
+        "r"(e), "n"(int32_t(spsel)),
+        "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)));
+#else
+    CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x16x64_F16E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED");
+#endif
+  }
+};
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+// SPARSE GMMA 64x16x64 TN F32+=E4M3*E4M3
+template <
+  GMMA::ScaleIn scaleA = GMMA::ScaleIn::One,
+  GMMA::ScaleIn scaleB = GMMA::ScaleIn::One,
+  GMMA::SparseSel spsel = GMMA::SparseSel::Zero
+>
+struct GMMA_64x16x64_F32E4M3E4M3_SS_TN
+{
+  using DRegisters = void;
+  using ARegisters = uint64_t[1];
+  using ERegisters = uint32_t[1];
+  using BRegisters = uint64_t[1];
+  using CRegisters = float[8];
+
+  CUTE_HOST_DEVICE static void
+  fma(uint64_t const& desc_a,
+      uint64_t const& desc_b,
+      float & d0, float & d1, float & d2, float & d3,
+      float & d4, float & d5, float & d6, float & d7,
+      uint32_t const& e,
+      GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
+  {
+#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
+    asm volatile(
+    "{\n"
+      ".reg .pred p;\n"
+      "setp.ne.b32 p, %12, 0;\n"
+      "wgmma.mma_async.sp.sync.aligned.m64n16k64.f32.e4m3.e4m3 "
+      "{%0, %1, %2, %3, %4, %5, %6, %7},"
+      " %8,"
+      " %9,"
+      " %10, %11,"
+      " p, %13, %14;\n"
+    "}\n"
+      : "+f"(d0), "+f"(d1), "+f"(d2), "+f"(d3),
+        "+f"(d4), "+f"(d5), "+f"(d6), "+f"(d7)
+      : "l"(desc_a),
+        "l"(desc_b),
+        "r"(e), "n"(int32_t(spsel)),
+        "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)));
+#else
+    CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x16x64_F32E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED");
+#endif
+  }
+};
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+// SPARSE GMMA 64x16x64 TN F32+=E4M3*E4M3
+template <
+  GMMA::ScaleIn scaleA = GMMA::ScaleIn::One,
+  GMMA::ScaleIn scaleB = GMMA::ScaleIn::One,
+  GMMA::SparseSel spsel = GMMA::SparseSel::Zero
+>
+struct GMMA_64x16x64_F32E4M3E4M3_RS_TN
+{
+  using DRegisters = void;
+  using ARegisters = uint32_t[4];
+  using ERegisters = uint32_t[1];
+  using BRegisters = uint64_t[1];
+  using CRegisters = float[8];
+
+  CUTE_HOST_DEVICE static void
+  fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3,
+      uint64_t const& desc_b,
+      float & d0, float & d1, float & d2, float & d3,
+      float & d4, float & d5, float & d6, float & d7,
+      uint32_t const& e,
+      GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
+  {
+#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
+    asm volatile(
+    "{\n"
+      ".reg .pred p;\n"
+      "setp.ne.b32 p, %15, 0;\n"
+      "wgmma.mma_async.sp.sync.aligned.m64n16k64.f32.e4m3.e4m3 "
+      "{%0, %1, %2, %3, %4, %5, %6, %7},"
+      "{%8, %9, %10, %11},"
+      " %12,"
+      " %13, %14,"
+      " p, %16, %17;\n"
+    "}\n"
+      : "+f"(d0), "+f"(d1), "+f"(d2), "+f"(d3),
+        "+f"(d4), "+f"(d5), "+f"(d6), "+f"(d7)
+      : "r"(a0), "r"(a1), "r"(a2), "r"(a3),
+        "l"(desc_b),
+        "r"(e),
"n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x16x64_F32E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x32x64 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x32x64_F16E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[8]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t & d4, uint32_t & d5, uint32_t & d6, uint32_t & d7, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %12, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n32k64.f16.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7}," + " %8," + " %9," + " %10, %11," + " p, %13, %14;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3), + "+r"(d4), "+r"(d5), "+r"(d6), "+r"(d7) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x32x64_F16E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x32x64 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x32x64_F16E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[8]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t & d4, uint32_t & d5, uint32_t & d6, uint32_t & d7, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %15, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n32k64.f16.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7}," + "{%8, %9, %10, %11}," + " %12," + " %13, %14," + " p, %16, %17;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3), + "+r"(d4), "+r"(d5), "+r"(d6), "+r"(d7) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x32x64_F16E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x32x64 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero 
+> +struct GMMA_64x32x64_F32E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[16]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %20, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n32k64.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15}," + " %16," + " %17," + " %18, %19," + " p, %21, %22;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x32x64_F32E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x32x64 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x32x64_F32E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[16]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %23, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n32k64.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15}," + "{%16, %17, %18, %19}," + " %20," + " %21, %22," + " p, %24, %25;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x32x64_F32E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x48x64 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct 
GMMA_64x48x64_F16E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[12]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %16, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n48k64.f16.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11}," + " %12," + " %13," + " %14, %15," + " p, %17, %18;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x48x64_F16E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x48x64 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x48x64_F16E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[12]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %19, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n48k64.f16.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11}," + "{%12, %13, %14, %15}," + " %16," + " %17, %18," + " p, %20, %21;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x48x64_F16E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x48x64 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x48x64_F32E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using 
ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[24]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %28, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n48k64.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23}," + " %24," + " %25," + " %26, %27," + " p, %29, %30;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x48x64_F32E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x48x64 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x48x64_F32E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[24]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %31, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n48k64.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23}," + "{%24, %25, %26, %27}," + " %28," + " %29, %30," + " p, %32, %33;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use 
SM90::GMMA::SPARSE::GMMA_64x48x64_F32E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x64x64 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x64x64_F16E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[16]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %20, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n64k64.f16.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15}," + " %16," + " %17," + " %18, %19," + " p, %21, %22;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x64x64_F16E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x64x64 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x64x64_F16E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[16]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %23, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n64k64.f16.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15}," + "{%16, %17, %18, %19}," + " %20," + " %21, %22," + " p, %24, %25;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + 
CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x64x64_F16E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x64x64 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x64x64_F32E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[32]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %36, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n64k64.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31}," + " %32," + " %33," + " %34, %35," + " p, %37, %38;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x64x64_F32E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x64x64 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x64x64_F32E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[32]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm 
volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %39, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n64k64.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31}," + "{%32, %33, %34, %35}," + " %36," + " %37, %38," + " p, %40, %41;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x64x64_F32E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x80x64 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x80x64_F16E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[20]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %24, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n80k64.f16.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19}," + " %20," + " %21," + " %22, %23," + " p, %25, %26;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x80x64_F16E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x80x64 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x80x64_F16E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters 
= uint32_t[20]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %27, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n80k64.f16.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19}," + "{%20, %21, %22, %23}," + " %24," + " %25, %26," + " p, %28, %29;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x80x64_F16E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x80x64 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x80x64_F32E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[40]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %44, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n80k64.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39}," + " %40," + " %41," + " %42, %43," + " p, %45, %46;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + 
"+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x80x64_F32E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x80x64 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x80x64_F32E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[40]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %47, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n80k64.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39}," + "{%40, %41, %42, %43}," + " %44," + " %45, %46," + " p, %48, %49;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x80x64_F32E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x96x64 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x96x64_F16E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[24]; + + 
CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %28, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n96k64.f16.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23}," + " %24," + " %25," + " %26, %27," + " p, %29, %30;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x96x64_F16E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x96x64 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x96x64_F16E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[24]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %31, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n96k64.f16.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23}," + "{%24, %25, %26, %27}," + " %28," + " %29, %30," + " p, %32, %33;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use 
SM90::GMMA::SPARSE::GMMA_64x96x64_F16E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x96x64 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x96x64_F32E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[48]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %52, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n96k64.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47}," + " %48," + " %49," + " %50, %51," + " p, %53, %54;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x96x64_F32E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x96x64 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x96x64_F32E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[48]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & 
d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %55, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n96k64.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47}," + "{%48, %49, %50, %51}," + " %52," + " %53, %54," + " p, %56, %57;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x96x64_F32E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x112x64 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x112x64_F16E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[28]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %32, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n112k64.f16.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27}," + " %28," + " %29," + " 
%30, %31," + " p, %33, %34;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x112x64_F16E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x112x64 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x112x64_F16E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[28]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %35, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n112k64.f16.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27}," + "{%28, %29, %30, %31}," + " %32," + " %33, %34," + " p, %36, %37;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x112x64_F16E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x112x64 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x112x64_F32E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[56]; + + 
CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %60, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n112k64.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55}," + " %56," + " %57," + " %58, %59," + " p, %61, %62;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x112x64_F32E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x112x64 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x112x64_F32E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[56]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float 
& d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %63, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n112k64.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55}," + "{%56, %57, %58, %59}," + " %60," + " %61, %62," + " p, %64, %65;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x112x64_F32E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x128x64 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x128x64_F16E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[32]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %36, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n128k64.f16.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31}," + " %32," + " %33," + " %34, %35," + " p, 
%37, %38;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x128x64_F16E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x128x64 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x128x64_F16E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[32]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %39, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n128k64.f16.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31}," + "{%32, %33, %34, %35}," + " %36," + " %37, %38," + " p, %40, %41;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x128x64_F16E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x128x64 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x128x64_F32E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + 
using CRegisters = float[64]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %68, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n128k64.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63}," + " %64," + " %65," + " %66, %67," + " p, %69, %70;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x128x64_F32E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x128x64 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x128x64_F32E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[64]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, 
float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %71, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n128k64.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63}," + "{%64, %65, %66, %67}," + " %68," + " %69, %70," + " p, %72, %73;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x128x64_F32E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x144x64 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x144x64_F16E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[36]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t 
& d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %40, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n144k64.f16.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35}," + " %36," + " %37," + " %38, %39," + " p, %41, %42;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x144x64_F16E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x144x64 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x144x64_F16E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[36]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %43, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n144k64.f16.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35}," + "{%36, %37, %38, %39}," + " %40," + " %41, %42," + " p, %44, %45;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + 
"+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x144x64_F16E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x144x64 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x144x64_F32E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[72]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %76, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n144k64.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71}," + " %72," + " %73," + " %74, %75," + " p, %77, %78;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71) + : 
"l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x144x64_F32E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x144x64 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x144x64_F32E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[72]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %79, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n144k64.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71}," + "{%72, %73, %74, %75}," + " %76," + " %77, %78," + " p, %80, %81;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71) + : 
"r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x144x64_F32E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x160x64 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x160x64_F16E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[40]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %44, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n160k64.f16.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39}," + " %40," + " %41," + " %42, %43," + " p, %45, %46;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x160x64_F16E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x160x64 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x160x64_F16E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[40]; + + CUTE_HOST_DEVICE static void + fma(uint32_t 
const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %47, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n160k64.f16.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39}," + "{%40, %41, %42, %43}," + " %44," + " %45, %46," + " p, %48, %49;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x160x64_F16E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x160x64 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x160x64_F32E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[80]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, 
float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %84, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n160k64.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79}," + " %80," + " %81," + " %82, %83," + " p, %85, %86;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x160x64_F32E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x160x64 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x160x64_F32E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[80]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, 
float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %87, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n160k64.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79}," + "{%80, %81, %82, %83}," + " %84," + " %85, %86," + " p, %88, %89;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x160x64_F32E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x176x64 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x176x64_F16E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[44]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, 
uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %48, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n176k64.f16.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43}," + " %44," + " %45," + " %46, %47," + " p, %49, %50;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x176x64_F16E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x176x64 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x176x64_F16E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[44]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %51, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n176k64.f16.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, 
%20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43}," + "{%44, %45, %46, %47}," + " %48," + " %49, %50," + " p, %52, %53;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x176x64_F16E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x176x64 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x176x64_F32E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[88]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + float & d80, float & d81, float & d82, float & d83, + float & d84, float & d85, float & d86, float & d87, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %92, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n176k64.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + 
" %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87}," + " %88," + " %89," + " %90, %91," + " p, %93, %94;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), + "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83), + "+f"(d84), "+f"(d85), "+f"(d86), "+f"(d87) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x176x64_F32E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x176x64 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x176x64_F32E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[88]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + float & d80, float & d81, float & d82, float & d83, + float & d84, float & d85, float & d86, float & d87, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %95, 
0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n176k64.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87}," + "{%88, %89, %90, %91}," + " %92," + " %93, %94," + " p, %96, %97;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), + "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83), + "+f"(d84), "+f"(d85), "+f"(d86), "+f"(d87) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x176x64_F32E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x192x64 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x192x64_F16E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[48]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + 
"setp.ne.b32 p, %52, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n192k64.f16.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47}," + " %48," + " %49," + " %50, %51," + " p, %53, %54;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x192x64_F16E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x192x64 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x192x64_F16E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[48]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %55, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n192k64.f16.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47}," + "{%48, %49, %50, %51}," + " %52," + " %53, %54," + " p, %56, %57;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + 
"+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x192x64_F16E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x192x64 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x192x64_F32E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[96]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + float & d80, float & d81, float & d82, float & d83, + float & d84, float & d85, float & d86, float & d87, + float & d88, float & d89, float & d90, float & d91, + float & d92, float & d93, float & d94, float & d95, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %100, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n192k64.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95}," + " %96," + " %97," + " %98, %99," + " p, %101, %102;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), 
"+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), + "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83), + "+f"(d84), "+f"(d85), "+f"(d86), "+f"(d87), + "+f"(d88), "+f"(d89), "+f"(d90), "+f"(d91), + "+f"(d92), "+f"(d93), "+f"(d94), "+f"(d95) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x192x64_F32E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x192x64 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x192x64_F32E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[96]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + float & d80, float & d81, float & d82, float & d83, + float & d84, float & d85, float & d86, float & d87, + float & d88, float & d89, float & d90, float & d91, + float & d92, float & d93, float & d94, float & d95, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %103, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n192k64.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, 
%18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95}," + "{%96, %97, %98, %99}," + " %100," + " %101, %102," + " p, %104, %105;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), + "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83), + "+f"(d84), "+f"(d85), "+f"(d86), "+f"(d87), + "+f"(d88), "+f"(d89), "+f"(d90), "+f"(d91), + "+f"(d92), "+f"(d93), "+f"(d94), "+f"(d95) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x192x64_F32E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x208x64 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x208x64_F16E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[52]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if 
defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %56, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n208k64.f16.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51}," + " %52," + " %53," + " %54, %55," + " p, %57, %58;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x208x64_F16E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x208x64 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x208x64_F16E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[52]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %59, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n208k64.f16.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51}," + 
"{%52, %53, %54, %55}," + " %56," + " %57, %58," + " p, %60, %61;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x208x64_F16E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x208x64 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x208x64_F32E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[104]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %108, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n208k64.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " 
%24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103}," + " %104," + " %105," + " %106, %107," + " p, %109, %110;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x208x64_F32E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x208x64 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x208x64_F32E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[104]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & 
d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %111, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n208k64.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103}," + "{%104, %105, %106, %107}," + " %108," + " %109, %110," + " p, %112, %113;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x208x64_F32E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 
64x224x64 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x224x64_F16E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[56]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %60, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n224k64.f16.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55}," + " %56," + " %57," + " %58, %59," + " p, %61, %62;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x224x64_F16E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x224x64 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x224x64_F16E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[56]; + + 
CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %63, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n224k64.f16.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55}," + "{%56, %57, %58, %59}," + " %60," + " %61, %62," + " p, %64, %65;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x224x64_F16E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x224x64 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x224x64_F32E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[112]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & 
d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %116, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n224k64.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111}," + " %112," + " %113," + " %114, %115," + " p, %117, %118;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), 
"+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x224x64_F32E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x224x64 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x224x64_F32E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[112]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %119, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n224k64.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, 
%74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111}," + "{%112, %113, %114, %115}," + " %116," + " %117, %118," + " p, %120, %121;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x224x64_F32E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x240x64 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x240x64_F16E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[60]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t 
& d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %64, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n240k64.f16.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59}," + " %60," + " %61," + " %62, %63," + " p, %65, %66;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x240x64_F16E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x240x64 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x240x64_F16E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[60]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t const& e, + GMMA::ScaleOut const scale_D = 
GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %67, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n240k64.f16.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59}," + "{%60, %61, %62, %63}," + " %64," + " %65, %66," + " p, %68, %69;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x240x64_F16E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x240x64 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x240x64_F32E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[120]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, 
float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + float & d112, float & d113, float & d114, float & d115, + float & d116, float & d117, float & d118, float & d119, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %124, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n240k64.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119}," + " %120," + " %121," + " %122, %123," + " p, %125, %126;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), + "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115), + "+f"(d116), "+f"(d117), "+f"(d118), "+f"(d119) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x240x64_F32E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x240x64 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = 
GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x240x64_F32E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[120]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + float & d112, float & d113, float & d114, float & d115, + float & d116, float & d117, float & d118, float & d119, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %127, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n240k64.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119}," + "{%120, %121, %122, %123}," + " %124," + " %125, %126," + " p, %128, %129;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), 
"+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), + "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115), + "+f"(d116), "+f"(d117), "+f"(d118), "+f"(d119) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x240x64_F32E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x256x64 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x256x64_F16E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[64]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %68, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n256k64.f16.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, 
%18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63}," + " %64," + " %65," + " %66, %67," + " p, %69, %70;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x256x64_F16E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x256x64 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x256x64_F16E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[64]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %71, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n256k64.f16.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " 
%40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63}," + "{%64, %65, %66, %67}," + " %68," + " %69, %70," + " p, %72, %73;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x256x64_F16E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x256x64 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x256x64_F32E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[128]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + float & d112, float & d113, float & d114, float & d115, + 
float & d116, float & d117, float & d118, float & d119, + float & d120, float & d121, float & d122, float & d123, + float & d124, float & d125, float & d126, float & d127, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %132, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n256k64.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119, " + " %120, %121, %122, %123, %124, %125, %126, %127}," + " %128," + " %129," + " %130, %131," + " p, %133, %134;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), + "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115), + "+f"(d116), "+f"(d117), "+f"(d118), "+f"(d119), + "+f"(d120), "+f"(d121), "+f"(d122), "+f"(d123), + "+f"(d124), "+f"(d125), "+f"(d126), "+f"(d127) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x256x64_F32E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x256x64 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x256x64_F32E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = 
uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[128]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + float & d112, float & d113, float & d114, float & d115, + float & d116, float & d117, float & d118, float & d119, + float & d120, float & d121, float & d122, float & d123, + float & d124, float & d125, float & d126, float & d127, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %135, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n256k64.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119, " + " %120, %121, %122, %123, %124, %125, %126, %127}," + "{%128, %129, %130, %131}," + " %132," + " %133, %134," + " p, %136, %137;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), 
"+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), + "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115), + "+f"(d116), "+f"(d117), "+f"(d118), "+f"(d119), + "+f"(d120), "+f"(d121), "+f"(d122), "+f"(d123), + "+f"(d124), "+f"(d125), "+f"(d126), "+f"(d127) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x256x64_F32E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x8x64 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x8x64_F16E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[2]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %6, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n8k64.f16.e4m3.e5m2 " + "{%0, %1}," + " %2," + " %3," + " %4, %5," + " p, %7, %8;\n" + "}\n" + : "+r"(d0), "+r"(d1) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x8x64_F16E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x8x64 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x8x64_F16E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[2]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, + 
uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %9, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n8k64.f16.e4m3.e5m2 " + "{%0, %1}," + "{%2, %3, %4, %5}," + " %6," + " %7, %8," + " p, %10, %11;\n" + "}\n" + : "+r"(d0), "+r"(d1) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x8x64_F16E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x8x64 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x8x64_F32E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[4]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d0, float & d1, float & d2, float & d3, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %8, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n8k64.f32.e4m3.e5m2 " + "{%0, %1, %2, %3}," + " %4," + " %5," + " %6, %7," + " p, %9, %10;\n" + "}\n" + : "+f"(d0), "+f"(d1), "+f"(d2), "+f"(d3) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x8x64_F32E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x8x64 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x8x64_F32E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[4]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint64_t const& desc_b, + float & d0, float & d1, float & d2, float & d3, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %11, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n8k64.f32.e4m3.e5m2 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + " %8," + " %9, %10," + " p, %12, %13;\n" + "}\n" + : "+f"(d0), "+f"(d1), "+f"(d2), "+f"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x8x64_F32E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x16x64 TN F16+=E4M3*E5M2 +template < 
+ GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x16x64_F16E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[4]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %8, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n16k64.f16.e4m3.e5m2 " + "{%0, %1, %2, %3}," + " %4," + " %5," + " %6, %7," + " p, %9, %10;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x16x64_F16E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x16x64 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x16x64_F16E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[4]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %11, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n16k64.f16.e4m3.e5m2 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + " %8," + " %9, %10," + " p, %12, %13;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x16x64_F16E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x16x64 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x16x64_F32E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[8]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d0, float & d1, float & d2, float & d3, + float & d4, float & d5, float & d6, float & d7, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %12, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n16k64.f32.e4m3.e5m2 " + "{%0, %1, 
%2, %3, %4, %5, %6, %7}," + " %8," + " %9," + " %10, %11," + " p, %13, %14;\n" + "}\n" + : "+f"(d0), "+f"(d1), "+f"(d2), "+f"(d3), + "+f"(d4), "+f"(d5), "+f"(d6), "+f"(d7) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x16x64_F32E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x16x64 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x16x64_F32E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[8]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint64_t const& desc_b, + float & d0, float & d1, float & d2, float & d3, + float & d4, float & d5, float & d6, float & d7, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %15, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n16k64.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7}," + "{%8, %9, %10, %11}," + " %12," + " %13, %14," + " p, %16, %17;\n" + "}\n" + : "+f"(d0), "+f"(d1), "+f"(d2), "+f"(d3), + "+f"(d4), "+f"(d5), "+f"(d6), "+f"(d7) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x16x64_F32E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x32x64 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x32x64_F16E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[8]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t & d4, uint32_t & d5, uint32_t & d6, uint32_t & d7, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %12, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n32k64.f16.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7}," + " %8," + " %9," + " %10, %11," + " p, %13, %14;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3), + "+r"(d4), "+r"(d5), "+r"(d6), "+r"(d7) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x32x64_F16E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// 
SPARSE GMMA 64x32x64 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x32x64_F16E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[8]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t & d4, uint32_t & d5, uint32_t & d6, uint32_t & d7, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %15, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n32k64.f16.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7}," + "{%8, %9, %10, %11}," + " %12," + " %13, %14," + " p, %16, %17;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3), + "+r"(d4), "+r"(d5), "+r"(d6), "+r"(d7) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x32x64_F16E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x32x64 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x32x64_F32E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[16]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %20, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n32k64.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15}," + " %16," + " %17," + " %18, %19," + " p, %21, %22;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x32x64_F32E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x32x64 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x32x64_F32E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters 
= uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[16]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %23, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n32k64.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15}," + "{%16, %17, %18, %19}," + " %20," + " %21, %22," + " p, %24, %25;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x32x64_F32E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x48x64 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x48x64_F16E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[12]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %16, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n48k64.f16.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11}," + " %12," + " %13," + " %14, %15," + " p, %17, %18;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x48x64_F16E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x48x64 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x48x64_F16E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = 
uint64_t[1]; + using CRegisters = uint32_t[12]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %19, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n48k64.f16.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11}," + "{%12, %13, %14, %15}," + " %16," + " %17, %18," + " p, %20, %21;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x48x64_F16E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x48x64 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x48x64_F32E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[24]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %28, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n48k64.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23}," + " %24," + " %25," + " %26, %27," + " p, %29, %30;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x48x64_F32E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x48x64 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = 
GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x48x64_F32E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[24]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %31, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n48k64.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23}," + "{%24, %25, %26, %27}," + " %28," + " %29, %30," + " p, %32, %33;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x48x64_F32E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x64x64 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x64x64_F16E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[16]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %20, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n64k64.f16.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15}," + " %16," + " %17," + " %18, %19," + " p, %21, %22;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x64x64_F16E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } 
+}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x64x64 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x64x64_F16E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[16]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %23, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n64k64.f16.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15}," + "{%16, %17, %18, %19}," + " %20," + " %21, %22," + " p, %24, %25;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x64x64_F16E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x64x64 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x64x64_F32E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[32]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %36, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n64k64.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31}," + " %32," + " %33," + " %34, %35," + " p, %37, %38;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + 
"+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x64x64_F32E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x64x64 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x64x64_F32E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[32]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %39, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n64k64.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31}," + "{%32, %33, %34, %35}," + " %36," + " %37, %38," + " p, %40, %41;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x64x64_F32E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x80x64 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x80x64_F16E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[20]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & 
d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %24, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n80k64.f16.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19}," + " %20," + " %21," + " %22, %23," + " p, %25, %26;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x80x64_F16E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x80x64 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x80x64_F16E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[20]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %27, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n80k64.f16.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19}," + "{%20, %21, %22, %23}," + " %24," + " %25, %26," + " p, %28, %29;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x80x64_F16E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x80x64 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct 
GMMA_64x80x64_F32E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[40]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %44, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n80k64.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39}," + " %40," + " %41," + " %42, %43," + " p, %45, %46;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x80x64_F32E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x80x64 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x80x64_F32E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[40]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 
p, %47, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n80k64.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39}," + "{%40, %41, %42, %43}," + " %44," + " %45, %46," + " p, %48, %49;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x80x64_F32E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x96x64 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x96x64_F16E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[24]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %28, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n96k64.f16.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23}," + " %24," + " %25," + " %26, %27," + " p, %29, %30;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x96x64_F16E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x96x64 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x96x64_F16E4M3E5M2_RS_TN +{ + using DRegisters = void; + using 
ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[24]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %31, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n96k64.f16.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23}," + "{%24, %25, %26, %27}," + " %28," + " %29, %30," + " p, %32, %33;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x96x64_F16E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x96x64 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x96x64_F32E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[48]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %52, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n96k64.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47}," + " %48," + " %49," + " %50, %51," + " p, %53, %54;\n" + 
"}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x96x64_F32E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x96x64 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x96x64_F32E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[48]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %55, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n96k64.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47}," + "{%48, %49, %50, %51}," + " %52," + " %53, %54," + " p, %56, %57;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use 
SM90::GMMA::SPARSE::GMMA_64x96x64_F32E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x112x64 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x112x64_F16E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[28]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %32, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n112k64.f16.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27}," + " %28," + " %29," + " %30, %31," + " p, %33, %34;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x112x64_F16E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x112x64 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x112x64_F16E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[28]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t const& e, + GMMA::ScaleOut const scale_D = 
GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %35, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n112k64.f16.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27}," + "{%28, %29, %30, %31}," + " %32," + " %33, %34," + " p, %36, %37;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x112x64_F16E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x112x64 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x112x64_F32E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[56]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %60, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n112k64.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55}," + " %56," + " %57," + " %58, %59," + " p, %61, %62;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), 
"+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x112x64_F32E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x112x64 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x112x64_F32E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[56]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %63, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n112k64.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55}," + "{%56, %57, %58, %59}," + " %60," + " %61, %62," + " p, %64, %65;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x112x64_F32E4M3E5M2_RS_TN without 
CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x128x64 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x128x64_F16E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[32]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %36, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n128k64.f16.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31}," + " %32," + " %33," + " %34, %35," + " p, %37, %38;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x128x64_F16E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x128x64 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x128x64_F16E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[32]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t const& e, + 
GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %39, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n128k64.f16.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31}," + "{%32, %33, %34, %35}," + " %36," + " %37, %38," + " p, %40, %41;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x128x64_F16E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x128x64 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x128x64_F32E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[64]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %68, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n128k64.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63}," + " %64," + " %65," + " %66, %67," + " p, %69, %70;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + 
"+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x128x64_F32E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x128x64 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x128x64_F32E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[64]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %71, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n128k64.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63}," + "{%64, %65, %66, %67}," + " %68," + " %69, %70," + " p, %72, %73;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), 
"+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x128x64_F32E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x144x64 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x144x64_F16E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[36]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %40, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n144k64.f16.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35}," + " %36," + " %37," + " %38, %39," + " p, %41, %42;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x144x64_F16E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x144x64 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x144x64_F16E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = 
uint32_t[36]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %43, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n144k64.f16.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35}," + "{%36, %37, %38, %39}," + " %40," + " %41, %42," + " p, %44, %45;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x144x64_F16E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x144x64 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x144x64_F32E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[72]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + 
float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %76, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n144k64.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71}," + " %72," + " %73," + " %74, %75," + " p, %77, %78;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x144x64_F32E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x144x64 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x144x64_F32E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[72]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float 
& d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %79, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n144k64.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71}," + "{%72, %73, %74, %75}," + " %76," + " %77, %78," + " p, %80, %81;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x144x64_F32E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x160x64 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x160x64_F16E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[40]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %44, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n160k64.f16.e4m3.e5m2 " + 
"{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39}," + " %40," + " %41," + " %42, %43," + " p, %45, %46;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x160x64_F16E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x160x64 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x160x64_F16E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[40]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %47, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n160k64.f16.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39}," + "{%40, %41, %42, %43}," + " %44," + " %45, %46," + " p, %48, %49;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), 
"n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x160x64_F16E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x160x64 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x160x64_F32E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[80]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %84, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n160k64.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79}," + " %80," + " %81," + " %82, %83," + " p, %85, %86;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + 
"+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x160x64_F32E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x160x64 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x160x64_F32E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[80]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %87, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n160k64.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79}," + "{%80, %81, %82, %83}," + " %84," + " %85, %86," + " p, %88, %89;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), 
"+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x160x64_F32E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x176x64 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x176x64_F16E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[44]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %48, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n176k64.f16.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43}," + " %44," + " %45," + " %46, %47," + " p, %49, %50;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x176x64_F16E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if 
defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x176x64 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x176x64_F16E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[44]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %51, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n176k64.f16.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43}," + "{%44, %45, %46, %47}," + " %48," + " %49, %50," + " p, %52, %53;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x176x64_F16E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x176x64 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x176x64_F32E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[88]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & 
d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + float & d80, float & d81, float & d82, float & d83, + float & d84, float & d85, float & d86, float & d87, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %92, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n176k64.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87}," + " %88," + " %89," + " %90, %91," + " p, %93, %94;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), + "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83), + "+f"(d84), "+f"(d85), "+f"(d86), "+f"(d87) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x176x64_F32E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x176x64 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct 
GMMA_64x176x64_F32E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[88]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + float & d80, float & d81, float & d82, float & d83, + float & d84, float & d85, float & d86, float & d87, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %95, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n176k64.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87}," + "{%88, %89, %90, %91}," + " %92," + " %93, %94," + " p, %96, %97;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), + "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83), + "+f"(d84), "+f"(d85), "+f"(d86), "+f"(d87) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + 
CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x176x64_F32E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x192x64 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x192x64_F16E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[48]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %52, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n192k64.f16.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47}," + " %48," + " %49," + " %50, %51," + " p, %53, %54;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x192x64_F16E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x192x64 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x192x64_F16E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[48]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& 
a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %55, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n192k64.f16.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47}," + "{%48, %49, %50, %51}," + " %52," + " %53, %54," + " p, %56, %57;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x192x64_F16E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x192x64 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x192x64_F32E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[96]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + 
float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + float & d80, float & d81, float & d82, float & d83, + float & d84, float & d85, float & d86, float & d87, + float & d88, float & d89, float & d90, float & d91, + float & d92, float & d93, float & d94, float & d95, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %100, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n192k64.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95}," + " %96," + " %97," + " %98, %99," + " p, %101, %102;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), + "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83), + "+f"(d84), "+f"(d85), "+f"(d86), "+f"(d87), + "+f"(d88), "+f"(d89), "+f"(d90), "+f"(d91), + "+f"(d92), "+f"(d93), "+f"(d94), "+f"(d95) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x192x64_F32E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x192x64 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x192x64_F32E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[96]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, 
uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + float & d80, float & d81, float & d82, float & d83, + float & d84, float & d85, float & d86, float & d87, + float & d88, float & d89, float & d90, float & d91, + float & d92, float & d93, float & d94, float & d95, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %103, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n192k64.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95}," + "{%96, %97, %98, %99}," + " %100," + " %101, %102," + " p, %104, %105;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), + "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83), + "+f"(d84), "+f"(d85), "+f"(d86), "+f"(d87), + "+f"(d88), "+f"(d89), "+f"(d90), "+f"(d91), + "+f"(d92), "+f"(d93), "+f"(d94), "+f"(d95) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + 
CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x192x64_F32E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x208x64 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x208x64_F16E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[52]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %56, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n208k64.f16.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51}," + " %52," + " %53," + " %54, %55," + " p, %57, %58;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x208x64_F16E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x208x64 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct 
GMMA_64x208x64_F16E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[52]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %59, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n208k64.f16.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51}," + "{%52, %53, %54, %55}," + " %56," + " %57, %58," + " p, %60, %61;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x208x64_F16E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x208x64 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x208x64_F32E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[104]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & 
d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %108, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n208k64.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103}," + " %104," + " %105," + " %106, %107," + " p, %109, %110;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), 
"+f"(d101), "+f"(d102), "+f"(d103) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x208x64_F32E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x208x64 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x208x64_F32E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[104]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %111, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n208k64.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103}," + "{%104, %105, %106, %107}," + " %108," + " %109, %110," + " p, %112, %113;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + 
"+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x208x64_F32E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x224x64 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x224x64_F16E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[56]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %60, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n224k64.f16.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, 
%26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55}," + " %56," + " %57," + " %58, %59," + " p, %61, %62;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x224x64_F16E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x224x64 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x224x64_F16E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[56]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %63, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n224k64.f16.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55}," + "{%56, %57, %58, %59}," + " %60," + " %61, %62," + " p, %64, %65;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), 
"+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x224x64_F16E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x224x64 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x224x64_F32E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[112]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %116, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n224k64.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " 
%24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111}," + " %112," + " %113," + " %114, %115," + " p, %117, %118;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x224x64_F32E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x224x64 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x224x64_F32E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[112]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, 
float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %119, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n224k64.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111}," + "{%112, %113, %114, %115}," + " %116," + " %117, %118," + " p, %120, %121;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(e), 
"n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x224x64_F32E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x240x64 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x240x64_F16E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[60]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %64, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n240k64.f16.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59}," + " %60," + " %61," + " %62, %63," + " p, %65, %66;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x240x64_F16E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif 
+ +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x240x64 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x240x64_F16E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[60]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %67, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n240k64.f16.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59}," + "{%60, %61, %62, %63}," + " %64," + " %65, %66," + " p, %68, %69;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x240x64_F16E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) 
+// SPARSE GMMA 64x240x64 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x240x64_F32E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[120]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + float & d112, float & d113, float & d114, float & d115, + float & d116, float & d117, float & d118, float & d119, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %124, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n240k64.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119}," + " %120," + " %121," + " %122, %123," + " p, %125, %126;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), 
"+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), + "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115), + "+f"(d116), "+f"(d117), "+f"(d118), "+f"(d119) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x240x64_F32E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x240x64 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x240x64_F32E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[120]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float 
& d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + float & d112, float & d113, float & d114, float & d115, + float & d116, float & d117, float & d118, float & d119, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %127, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n240k64.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119}," + "{%120, %121, %122, %123}," + " %124," + " %125, %126," + " p, %128, %129;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), + "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115), + "+f"(d116), "+f"(d117), "+f"(d118), "+f"(d119) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x240x64_F32E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x256x64 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + 
GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x256x64_F16E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[64]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %68, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n256k64.f16.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63}," + " %64," + " %65," + " %66, %67," + " p, %69, %70;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x256x64_F16E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x256x64 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x256x64_F16E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + 
using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[64]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %71, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n256k64.f16.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63}," + "{%64, %65, %66, %67}," + " %68," + " %69, %70," + " p, %72, %73;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x256x64_F16E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x256x64 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x256x64_F32E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[128]; + + CUTE_HOST_DEVICE 
static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + float & d112, float & d113, float & d114, float & d115, + float & d116, float & d117, float & d118, float & d119, + float & d120, float & d121, float & d122, float & d123, + float & d124, float & d125, float & d126, float & d127, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %132, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n256k64.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119, " + " %120, %121, %122, %123, %124, %125, %126, %127}," + " %128," + " %129," + " %130, %131," + " p, %133, %134;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), 
"+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), + "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115), + "+f"(d116), "+f"(d117), "+f"(d118), "+f"(d119), + "+f"(d120), "+f"(d121), "+f"(d122), "+f"(d123), + "+f"(d124), "+f"(d125), "+f"(d126), "+f"(d127) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x256x64_F32E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x256x64 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x256x64_F32E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[128]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & 
d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + float & d112, float & d113, float & d114, float & d115, + float & d116, float & d117, float & d118, float & d119, + float & d120, float & d121, float & d122, float & d123, + float & d124, float & d125, float & d126, float & d127, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %135, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n256k64.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119, " + " %120, %121, %122, %123, %124, %125, %126, %127}," + "{%128, %129, %130, %131}," + " %132," + " %133, %134," + " p, %136, %137;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), + "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115), + "+f"(d116), "+f"(d117), "+f"(d118), "+f"(d119), + "+f"(d120), "+f"(d121), "+f"(d122), "+f"(d123), + "+f"(d124), "+f"(d125), "+f"(d126), "+f"(d127) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x256x64_F32E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x8x64 TN F16+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, 
+ GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x8x64_F16E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[2]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %6, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n8k64.f16.e5m2.e4m3 " + "{%0, %1}," + " %2," + " %3," + " %4, %5," + " p, %7, %8;\n" + "}\n" + : "+r"(d0), "+r"(d1) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x8x64_F16E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x8x64 TN F16+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x8x64_F16E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[2]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %9, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n8k64.f16.e5m2.e4m3 " + "{%0, %1}," + "{%2, %3, %4, %5}," + " %6," + " %7, %8," + " p, %10, %11;\n" + "}\n" + : "+r"(d0), "+r"(d1) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x8x64_F16E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x8x64 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x8x64_F32E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[4]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d0, float & d1, float & d2, float & d3, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %8, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n8k64.f32.e5m2.e4m3 " + "{%0, %1, %2, %3}," + " %4," + " %5," + " %6, %7," + " p, %9, %10;\n" + "}\n" + : "+f"(d0), "+f"(d1), "+f"(d2), "+f"(d3) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), 
"n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x8x64_F32E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x8x64 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x8x64_F32E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[4]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint64_t const& desc_b, + float & d0, float & d1, float & d2, float & d3, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %11, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n8k64.f32.e5m2.e4m3 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + " %8," + " %9, %10," + " p, %12, %13;\n" + "}\n" + : "+f"(d0), "+f"(d1), "+f"(d2), "+f"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x8x64_F32E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x16x64 TN F16+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x16x64_F16E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[4]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %8, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n16k64.f16.e5m2.e4m3 " + "{%0, %1, %2, %3}," + " %4," + " %5," + " %6, %7," + " p, %9, %10;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x16x64_F16E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x16x64 TN F16+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x16x64_F16E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[4]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + 
uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %11, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n16k64.f16.e5m2.e4m3 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + " %8," + " %9, %10," + " p, %12, %13;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x16x64_F16E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x16x64 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x16x64_F32E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[8]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d0, float & d1, float & d2, float & d3, + float & d4, float & d5, float & d6, float & d7, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %12, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n16k64.f32.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7}," + " %8," + " %9," + " %10, %11," + " p, %13, %14;\n" + "}\n" + : "+f"(d0), "+f"(d1), "+f"(d2), "+f"(d3), + "+f"(d4), "+f"(d5), "+f"(d6), "+f"(d7) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x16x64_F32E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x16x64 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x16x64_F32E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[8]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint64_t const& desc_b, + float & d0, float & d1, float & d2, float & d3, + float & d4, float & d5, float & d6, float & d7, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %15, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n16k64.f32.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7}," + "{%8, %9, %10, %11}," + " %12," + " %13, %14," + " p, %16, %17;\n" + "}\n" + : "+f"(d0), "+f"(d1), "+f"(d2), "+f"(d3), + "+f"(d4), "+f"(d5), "+f"(d6), "+f"(d7) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), 
"n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x16x64_F32E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x32x64 TN F16+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x32x64_F16E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[8]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t & d4, uint32_t & d5, uint32_t & d6, uint32_t & d7, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %12, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n32k64.f16.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7}," + " %8," + " %9," + " %10, %11," + " p, %13, %14;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3), + "+r"(d4), "+r"(d5), "+r"(d6), "+r"(d7) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x32x64_F16E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x32x64 TN F16+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x32x64_F16E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[8]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t & d4, uint32_t & d5, uint32_t & d6, uint32_t & d7, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %15, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n32k64.f16.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7}," + "{%8, %9, %10, %11}," + " %12," + " %13, %14," + " p, %16, %17;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3), + "+r"(d4), "+r"(d5), "+r"(d6), "+r"(d7) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x32x64_F16E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x32x64 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x32x64_F32E5M2E4M3_SS_TN +{ 
+ using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[16]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %20, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n32k64.f32.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15}," + " %16," + " %17," + " %18, %19," + " p, %21, %22;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x32x64_F32E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x32x64 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x32x64_F32E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[16]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %23, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n32k64.f32.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15}," + "{%16, %17, %18, %19}," + " %20," + " %21, %22," + " p, %24, %25;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x32x64_F32E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x48x64 TN F16+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x48x64_F16E5M2E4M3_SS_TN +{ + using DRegisters 
= void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[12]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %16, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n48k64.f16.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11}," + " %12," + " %13," + " %14, %15," + " p, %17, %18;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x48x64_F16E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x48x64 TN F16+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x48x64_F16E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[12]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %19, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n48k64.f16.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11}," + "{%12, %13, %14, %15}," + " %16," + " %17, %18," + " p, %20, %21;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x48x64_F16E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x48x64 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x48x64_F32E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + 
using CRegisters = float[24]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %28, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n48k64.f32.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23}," + " %24," + " %25," + " %26, %27," + " p, %29, %30;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x48x64_F32E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x48x64 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x48x64_F32E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[24]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %31, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n48k64.f32.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23}," + "{%24, %25, %26, %27}," + " %28," + " %29, %30," + " p, %32, %33;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x48x64_F32E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } 
+}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x64x64 TN F16+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x64x64_F16E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[16]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %20, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n64k64.f16.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15}," + " %16," + " %17," + " %18, %19," + " p, %21, %22;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x64x64_F16E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x64x64 TN F16+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x64x64_F16E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[16]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %23, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n64k64.f16.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15}," + "{%16, %17, %18, %19}," + " %20," + " %21, %22," + " p, %24, %25;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x64x64_F16E5M2E4M3_RS_TN without 
CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x64x64 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x64x64_F32E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[32]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %36, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n64k64.f32.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31}," + " %32," + " %33," + " %34, %35," + " p, %37, %38;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x64x64_F32E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x64x64 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x64x64_F32E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[32]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %39, 0;\n" + 
"wgmma.mma_async.sp.sync.aligned.m64n64k64.f32.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31}," + "{%32, %33, %34, %35}," + " %36," + " %37, %38," + " p, %40, %41;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x64x64_F32E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x80x64 TN F16+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x80x64_F16E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[20]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %24, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n80k64.f16.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19}," + " %20," + " %21," + " %22, %23," + " p, %25, %26;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x80x64_F16E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x80x64 TN F16+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x80x64_F16E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[20]; + + CUTE_HOST_DEVICE static void + fma(uint32_t 
const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %27, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n80k64.f16.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19}," + "{%20, %21, %22, %23}," + " %24," + " %25, %26," + " p, %28, %29;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x80x64_F16E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x80x64 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x80x64_F32E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[40]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %44, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n80k64.f32.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39}," + " %40," + " %41," + " %42, %43," + " p, %45, %46;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), 
"+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x80x64_F32E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x80x64 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x80x64_F32E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[40]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %47, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n80k64.f32.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39}," + "{%40, %41, %42, %43}," + " %44," + " %45, %46," + " p, %48, %49;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x80x64_F32E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x96x64 TN F16+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x96x64_F16E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[24]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& 
desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %28, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n96k64.f16.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23}," + " %24," + " %25," + " %26, %27," + " p, %29, %30;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x96x64_F16E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x96x64 TN F16+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x96x64_F16E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[24]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %31, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n96k64.f16.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23}," + "{%24, %25, %26, %27}," + " %28," + " %29, %30," + " p, %32, %33;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x96x64_F16E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + 
+//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x96x64 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x96x64_F32E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[48]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %52, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n96k64.f32.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47}," + " %48," + " %49," + " %50, %51," + " p, %53, %54;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x96x64_F32E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x96x64 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x96x64_F32E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[48]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & 
d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %55, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n96k64.f32.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47}," + "{%48, %49, %50, %51}," + " %52," + " %53, %54," + " p, %56, %57;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x96x64_F32E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x112x64 TN F16+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x112x64_F16E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[28]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %32, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n112k64.f16.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27}," + " %28," + " %29," + " %30, %31," + " p, %33, %34;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), 
"+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x112x64_F16E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x112x64 TN F16+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x112x64_F16E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[28]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %35, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n112k64.f16.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27}," + "{%28, %29, %30, %31}," + " %32," + " %33, %34," + " p, %36, %37;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x112x64_F16E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x112x64 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x112x64_F32E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[56]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & 
d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %60, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n112k64.f32.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55}," + " %56," + " %57," + " %58, %59," + " p, %61, %62;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x112x64_F32E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x112x64 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x112x64_F32E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[56]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + 
float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %63, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n112k64.f32.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55}," + "{%56, %57, %58, %59}," + " %60," + " %61, %62," + " p, %64, %65;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x112x64_F32E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x128x64 TN F16+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x128x64_F16E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[32]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %36, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n128k64.f16.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31}," + " %32," + " %33," + " %34, %35," + " p, %37, %38;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + 
"+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x128x64_F16E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x128x64 TN F16+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x128x64_F16E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[32]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %39, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n128k64.f16.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31}," + "{%32, %33, %34, %35}," + " %36," + " %37, %38," + " p, %40, %41;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x128x64_F16E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x128x64 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x128x64_F32E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[64]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + 
float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %68, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n128k64.f32.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63}," + " %64," + " %65," + " %66, %67," + " p, %69, %70;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x128x64_F32E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x128x64 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x128x64_F32E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[64]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + 
float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %71, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n128k64.f32.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63}," + "{%64, %65, %66, %67}," + " %68," + " %69, %70," + " p, %72, %73;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x128x64_F32E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x144x64 TN F16+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x144x64_F16E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[36]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t const& 
e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %40, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n144k64.f16.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35}," + " %36," + " %37," + " %38, %39," + " p, %41, %42;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x144x64_F16E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x144x64 TN F16+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x144x64_F16E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[36]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %43, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n144k64.f16.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35}," + "{%36, %37, %38, %39}," + " %40," + " %41, %42," + " p, %44, %45;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), 
"n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x144x64_F16E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x144x64 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x144x64_F32E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[72]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %76, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n144k64.f32.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71}," + " %72," + " %73," + " %74, %75," + " p, %77, %78;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), 
"n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x144x64_F32E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x144x64 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x144x64_F32E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[72]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %79, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n144k64.f32.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71}," + "{%72, %73, %74, %75}," + " %76," + " %77, %78," + " p, %80, %81;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + 
"r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x144x64_F32E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x160x64 TN F16+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x160x64_F16E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[40]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %44, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n160k64.f16.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39}," + " %40," + " %41," + " %42, %43," + " p, %45, %46;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x160x64_F16E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x160x64 TN F16+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x160x64_F16E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[40]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t 
const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %47, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n160k64.f16.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39}," + "{%40, %41, %42, %43}," + " %44," + " %45, %46," + " p, %48, %49;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x160x64_F16E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x160x64 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x160x64_F32E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[80]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float 
& d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %84, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n160k64.f32.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79}," + " %80," + " %81," + " %82, %83," + " p, %85, %86;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x160x64_F32E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x160x64 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x160x64_F32E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[80]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & 
d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %87, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n160k64.f32.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79}," + "{%80, %81, %82, %83}," + " %84," + " %85, %86," + " p, %88, %89;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x160x64_F32E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x176x64 TN F16+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x176x64_F16E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[44]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & 
d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %48, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n176k64.f16.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43}," + " %44," + " %45," + " %46, %47," + " p, %49, %50;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x176x64_F16E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x176x64 TN F16+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x176x64_F16E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[44]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %51, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n176k64.f16.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, 
%36, %37, %38, %39, " + " %40, %41, %42, %43}," + "{%44, %45, %46, %47}," + " %48," + " %49, %50," + " p, %52, %53;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x176x64_F16E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x176x64 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x176x64_F32E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[88]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + float & d80, float & d81, float & d82, float & d83, + float & d84, float & d85, float & d86, float & d87, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %92, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n176k64.f32.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87}," + 
" %88," + " %89," + " %90, %91," + " p, %93, %94;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), + "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83), + "+f"(d84), "+f"(d85), "+f"(d86), "+f"(d87) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x176x64_F32E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x176x64 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x176x64_F32E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[88]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + float & d80, float & d81, float & d82, float & d83, + float & d84, float & d85, float & d86, float & d87, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %95, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n176k64.f32.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, 
%5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87}," + "{%88, %89, %90, %91}," + " %92," + " %93, %94," + " p, %96, %97;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), + "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83), + "+f"(d84), "+f"(d85), "+f"(d86), "+f"(d87) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x176x64_F32E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x192x64 TN F16+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x192x64_F16E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[48]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %52, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n192k64.f16.e5m2.e4m3 " + "{%0, %1, 
%2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47}," + " %48," + " %49," + " %50, %51," + " p, %53, %54;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x192x64_F16E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x192x64 TN F16+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x192x64_F16E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[48]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %55, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n192k64.f16.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47}," + "{%48, %49, %50, %51}," + " %52," + " %53, %54," + " p, %56, %57;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), 
"+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x192x64_F16E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x192x64 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x192x64_F32E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[96]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + float & d80, float & d81, float & d82, float & d83, + float & d84, float & d85, float & d86, float & d87, + float & d88, float & d89, float & d90, float & d91, + float & d92, float & d93, float & d94, float & d95, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %100, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n192k64.f32.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95}," + " %96," + " %97," + " %98, %99," + " p, %101, %102;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), 
"+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), + "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83), + "+f"(d84), "+f"(d85), "+f"(d86), "+f"(d87), + "+f"(d88), "+f"(d89), "+f"(d90), "+f"(d91), + "+f"(d92), "+f"(d93), "+f"(d94), "+f"(d95) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x192x64_F32E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x192x64 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x192x64_F32E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[96]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + float & d80, float & d81, float & d82, float & d83, + float & d84, float & d85, float & d86, float & d87, + float & d88, float & d89, float & d90, float & d91, + float & d92, float & d93, float & d94, float & d95, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %103, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n192k64.f32.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, 
%36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95}," + "{%96, %97, %98, %99}," + " %100," + " %101, %102," + " p, %104, %105;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), + "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83), + "+f"(d84), "+f"(d85), "+f"(d86), "+f"(d87), + "+f"(d88), "+f"(d89), "+f"(d90), "+f"(d91), + "+f"(d92), "+f"(d93), "+f"(d94), "+f"(d95) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x192x64_F32E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x208x64 TN F16+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x208x64_F16E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[52]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %56, 
0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n208k64.f16.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51}," + " %52," + " %53," + " %54, %55," + " p, %57, %58;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x208x64_F16E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x208x64 TN F16+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x208x64_F16E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[52]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %59, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n208k64.f16.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51}," + "{%52, %53, %54, %55}," + " %56," + " %57, %58," + " p, %60, %61;\n" + "}\n" + : "+r"(d00), "+r"(d01), 
"+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x208x64_F16E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x208x64 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x208x64_F32E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[104]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %108, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n208k64.f32.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, 
%42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103}," + " %104," + " %105," + " %106, %107," + " p, %109, %110;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x208x64_F32E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x208x64 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x208x64_F32E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[104]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & 
d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %111, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n208k64.f32.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103}," + "{%104, %105, %106, %107}," + " %108," + " %109, %110," + " p, %112, %113;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x208x64_F32E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x224x64 TN F16+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn 
scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x224x64_F16E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[56]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %60, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n224k64.f16.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55}," + " %56," + " %57," + " %58, %59," + " p, %61, %62;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x224x64_F16E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x224x64 TN F16+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x224x64_F16E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[56]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t 
const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %63, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n224k64.f16.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55}," + "{%56, %57, %58, %59}," + " %60," + " %61, %62," + " p, %64, %65;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x224x64_F16E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x224x64 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x224x64_F32E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[112]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, 
float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %116, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n224k64.f32.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111}," + " %112," + " %113," + " %114, %115," + " p, %117, %118;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), 
"+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x224x64_F32E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x224x64 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x224x64_F32E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[112]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %119, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n224k64.f32.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, 
%94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111}," + "{%112, %113, %114, %115}," + " %116," + " %117, %118," + " p, %120, %121;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x224x64_F32E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x240x64 TN F16+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x240x64_F16E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[60]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, 
uint32_t & d58, uint32_t & d59, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %64, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n240k64.f16.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59}," + " %60," + " %61," + " %62, %63," + " p, %65, %66;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x240x64_F16E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x240x64 TN F16+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x240x64_F16E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[60]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + 
"setp.ne.b32 p, %67, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n240k64.f16.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59}," + "{%60, %61, %62, %63}," + " %64," + " %65, %66," + " p, %68, %69;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x240x64_F16E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x240x64 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x240x64_F32E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[120]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, 
float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + float & d112, float & d113, float & d114, float & d115, + float & d116, float & d117, float & d118, float & d119, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %124, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n240k64.f32.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119}," + " %120," + " %121," + " %122, %123," + " p, %125, %126;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), + "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115), + "+f"(d116), "+f"(d117), "+f"(d118), "+f"(d119) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x240x64_F32E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x240x64 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct 
GMMA_64x240x64_F32E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[120]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + float & d112, float & d113, float & d114, float & d115, + float & d116, float & d117, float & d118, float & d119, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %127, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n240k64.f32.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119}," + "{%120, %121, %122, %123}," + " %124," + " %125, %126," + " p, %128, %129;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), 
"+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), + "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115), + "+f"(d116), "+f"(d117), "+f"(d118), "+f"(d119) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x240x64_F32E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x256x64 TN F16+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x256x64_F16E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[64]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %68, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n256k64.f16.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " 
%40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63}," + " %64," + " %65," + " %66, %67," + " p, %69, %70;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x256x64_F16E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x256x64 TN F16+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x256x64_F16E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[64]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %71, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n256k64.f16.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, 
%63}," + "{%64, %65, %66, %67}," + " %68," + " %69, %70," + " p, %72, %73;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x256x64_F16E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x256x64 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x256x64_F32E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[128]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + float & d112, float & d113, float & d114, float & d115, + float & d116, float & d117, float & d118, float & d119, + float & d120, float & d121, float & d122, float & d123, + float & 
d124, float & d125, float & d126, float & d127, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %132, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n256k64.f32.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119, " + " %120, %121, %122, %123, %124, %125, %126, %127}," + " %128," + " %129," + " %130, %131," + " p, %133, %134;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), + "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115), + "+f"(d116), "+f"(d117), "+f"(d118), "+f"(d119), + "+f"(d120), "+f"(d121), "+f"(d122), "+f"(d123), + "+f"(d124), "+f"(d125), "+f"(d126), "+f"(d127) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x256x64_F32E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x256x64 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x256x64_F32E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[128]; + + CUTE_HOST_DEVICE 
static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + float & d112, float & d113, float & d114, float & d115, + float & d116, float & d117, float & d118, float & d119, + float & d120, float & d121, float & d122, float & d123, + float & d124, float & d125, float & d126, float & d127, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %135, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n256k64.f32.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119, " + " %120, %121, %122, %123, %124, %125, %126, %127}," + "{%128, %129, %130, %131}," + " %132," + " %133, %134," + " p, %136, %137;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), 
"+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), + "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115), + "+f"(d116), "+f"(d117), "+f"(d118), "+f"(d119), + "+f"(d120), "+f"(d121), "+f"(d122), "+f"(d123), + "+f"(d124), "+f"(d125), "+f"(d126), "+f"(d127) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x256x64_F32E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x8x64 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x8x64_F16E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[2]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %6, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n8k64.f16.e5m2.e5m2 " + "{%0, %1}," + " %2," + " %3," + " %4, %5," + " p, %7, %8;\n" + "}\n" + : "+r"(d0), "+r"(d1) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x8x64_F16E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x8x64 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x8x64_F16E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[2]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + 
"{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %9, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n8k64.f16.e5m2.e5m2 " + "{%0, %1}," + "{%2, %3, %4, %5}," + " %6," + " %7, %8," + " p, %10, %11;\n" + "}\n" + : "+r"(d0), "+r"(d1) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x8x64_F16E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x8x64 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x8x64_F32E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[4]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d0, float & d1, float & d2, float & d3, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %8, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n8k64.f32.e5m2.e5m2 " + "{%0, %1, %2, %3}," + " %4," + " %5," + " %6, %7," + " p, %9, %10;\n" + "}\n" + : "+f"(d0), "+f"(d1), "+f"(d2), "+f"(d3) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x8x64_F32E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x8x64 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x8x64_F32E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[4]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint64_t const& desc_b, + float & d0, float & d1, float & d2, float & d3, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %11, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n8k64.f32.e5m2.e5m2 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + " %8," + " %9, %10," + " p, %12, %13;\n" + "}\n" + : "+f"(d0), "+f"(d1), "+f"(d2), "+f"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x8x64_F32E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x16x64 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero 
+> +struct GMMA_64x16x64_F16E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[4]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %8, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n16k64.f16.e5m2.e5m2 " + "{%0, %1, %2, %3}," + " %4," + " %5," + " %6, %7," + " p, %9, %10;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x16x64_F16E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x16x64 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x16x64_F16E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[4]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %11, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n16k64.f16.e5m2.e5m2 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + " %8," + " %9, %10," + " p, %12, %13;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x16x64_F16E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x16x64 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x16x64_F32E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[8]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d0, float & d1, float & d2, float & d3, + float & d4, float & d5, float & d6, float & d7, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %12, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n16k64.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7}," + " %8," + " %9," + " %10, %11," + " p, %13, %14;\n" + "}\n" + : "+f"(d0), "+f"(d1), "+f"(d2), "+f"(d3), + 
"+f"(d4), "+f"(d5), "+f"(d6), "+f"(d7) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x16x64_F32E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x16x64 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x16x64_F32E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[8]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint64_t const& desc_b, + float & d0, float & d1, float & d2, float & d3, + float & d4, float & d5, float & d6, float & d7, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %15, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n16k64.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7}," + "{%8, %9, %10, %11}," + " %12," + " %13, %14," + " p, %16, %17;\n" + "}\n" + : "+f"(d0), "+f"(d1), "+f"(d2), "+f"(d3), + "+f"(d4), "+f"(d5), "+f"(d6), "+f"(d7) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x16x64_F32E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x32x64 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x32x64_F16E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[8]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t & d4, uint32_t & d5, uint32_t & d6, uint32_t & d7, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %12, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n32k64.f16.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7}," + " %8," + " %9," + " %10, %11," + " p, %13, %14;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3), + "+r"(d4), "+r"(d5), "+r"(d6), "+r"(d7) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x32x64_F16E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x32x64 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = 
GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x32x64_F16E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[8]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t & d4, uint32_t & d5, uint32_t & d6, uint32_t & d7, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %15, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n32k64.f16.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7}," + "{%8, %9, %10, %11}," + " %12," + " %13, %14," + " p, %16, %17;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3), + "+r"(d4), "+r"(d5), "+r"(d6), "+r"(d7) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x32x64_F16E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x32x64 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x32x64_F32E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[16]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %20, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n32k64.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15}," + " %16," + " %17," + " %18, %19," + " p, %21, %22;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x32x64_F32E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x32x64 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x32x64_F32E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[16]; + + CUTE_HOST_DEVICE static void + 
fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %23, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n32k64.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15}," + "{%16, %17, %18, %19}," + " %20," + " %21, %22," + " p, %24, %25;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x32x64_F32E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x48x64 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x48x64_F16E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[12]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %16, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n48k64.f16.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11}," + " %12," + " %13," + " %14, %15," + " p, %17, %18;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x48x64_F16E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x48x64 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x48x64_F16E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[12]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t 
const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %19, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n48k64.f16.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11}," + "{%12, %13, %14, %15}," + " %16," + " %17, %18," + " p, %20, %21;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x48x64_F16E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x48x64 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x48x64_F32E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[24]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %28, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n48k64.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23}," + " %24," + " %25," + " %26, %27," + " p, %29, %30;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x48x64_F32E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x48x64 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x48x64_F32E5M2E5M2_RS_TN +{ + using DRegisters 
= void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[24]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %31, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n48k64.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23}," + "{%24, %25, %26, %27}," + " %28," + " %29, %30," + " p, %32, %33;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x48x64_F32E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x64x64 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x64x64_F16E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[16]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %20, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n64k64.f16.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15}," + " %16," + " %17," + " %18, %19," + " p, %21, %22;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x64x64_F16E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x64x64 
TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x64x64_F16E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[16]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %23, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n64k64.f16.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15}," + "{%16, %17, %18, %19}," + " %20," + " %21, %22," + " p, %24, %25;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x64x64_F16E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x64x64 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x64x64_F32E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[32]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %36, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n64k64.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31}," + " %32," + " %33," + " %34, %35," + " p, %37, %38;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), 
"+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x64x64_F32E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x64x64 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x64x64_F32E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[32]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %39, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n64k64.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31}," + "{%32, %33, %34, %35}," + " %36," + " %37, %38," + " p, %40, %41;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x64x64_F32E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x80x64 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x80x64_F16E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[20]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + 
uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %24, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n80k64.f16.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19}," + " %20," + " %21," + " %22, %23," + " p, %25, %26;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x80x64_F16E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x80x64 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x80x64_F16E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[20]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %27, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n80k64.f16.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19}," + "{%20, %21, %22, %23}," + " %24," + " %25, %26," + " p, %28, %29;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x80x64_F16E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x80x64 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x80x64_F32E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = 
uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[40]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %44, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n80k64.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39}," + " %40," + " %41," + " %42, %43," + " p, %45, %46;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x80x64_F32E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x80x64 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x80x64_F32E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[40]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %47, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n80k64.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " 
%8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39}," + "{%40, %41, %42, %43}," + " %44," + " %45, %46," + " p, %48, %49;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x80x64_F32E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x96x64 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x96x64_F16E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[24]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %28, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n96k64.f16.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23}," + " %24," + " %25," + " %26, %27," + " p, %29, %30;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x96x64_F16E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x96x64 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x96x64_F16E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = 
uint32_t[24]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %31, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n96k64.f16.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23}," + "{%24, %25, %26, %27}," + " %28," + " %29, %30," + " p, %32, %33;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x96x64_F16E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x96x64 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x96x64_F32E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[48]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %52, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n96k64.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47}," + " %48," + " %49," + " %50, %51," + " p, %53, %54;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), 
"+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x96x64_F32E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x96x64 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x96x64_F32E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[48]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %55, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n96k64.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47}," + "{%48, %49, %50, %51}," + " %52," + " %53, %54," + " p, %56, %57;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x96x64_F32E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + 
+//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x112x64 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x112x64_F16E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[28]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %32, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n112k64.f16.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27}," + " %28," + " %29," + " %30, %31," + " p, %33, %34;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x112x64_F16E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x112x64 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x112x64_F16E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[28]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + 
"setp.ne.b32 p, %35, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n112k64.f16.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27}," + "{%28, %29, %30, %31}," + " %32," + " %33, %34," + " p, %36, %37;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x112x64_F16E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x112x64 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x112x64_F32E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[56]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %60, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n112k64.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55}," + " %56," + " %57," + " %58, %59," + " p, %61, %62;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), 
"+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x112x64_F32E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x112x64 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x112x64_F32E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[56]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %63, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n112k64.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55}," + "{%56, %57, %58, %59}," + " %60," + " %61, %62," + " p, %64, %65;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x112x64_F32E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + 
+//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x128x64 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x128x64_F16E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[32]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %36, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n128k64.f16.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31}," + " %32," + " %33," + " %34, %35," + " p, %37, %38;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x128x64_F16E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x128x64 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x128x64_F16E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[32]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if 
defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %39, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n128k64.f16.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31}," + "{%32, %33, %34, %35}," + " %36," + " %37, %38," + " p, %40, %41;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x128x64_F16E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x128x64 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x128x64_F32E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[64]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %68, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n128k64.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63}," + " %64," + " %65," + " %66, %67," + " p, %69, %70;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), 
"+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x128x64_F32E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x128x64 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x128x64_F32E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[64]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %71, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n128k64.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63}," + "{%64, %65, %66, %67}," + " %68," + " %69, %70," + " p, %72, %73;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), 
+ "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x128x64_F32E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x144x64 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x144x64_F16E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[36]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %40, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n144k64.f16.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35}," + " %36," + " %37," + " %38, %39," + " p, %41, %42;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x144x64_F16E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x144x64 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x144x64_F16E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[36]; + + CUTE_HOST_DEVICE static void + fma(uint32_t 
const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %43, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n144k64.f16.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35}," + "{%36, %37, %38, %39}," + " %40," + " %41, %42," + " p, %44, %45;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x144x64_F16E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x144x64 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x144x64_F32E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[72]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & 
d68, float & d69, float & d70, float & d71, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %76, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n144k64.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71}," + " %72," + " %73," + " %74, %75," + " p, %77, %78;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x144x64_F32E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x144x64 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x144x64_F32E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[72]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, 
float & d70, float & d71, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %79, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n144k64.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71}," + "{%72, %73, %74, %75}," + " %76," + " %77, %78," + " p, %80, %81;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x144x64_F32E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x160x64 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x160x64_F16E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[40]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %44, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n160k64.f16.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, 
%12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39}," + " %40," + " %41," + " %42, %43," + " p, %45, %46;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x160x64_F16E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x160x64 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x160x64_F16E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[40]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %47, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n160k64.f16.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39}," + "{%40, %41, %42, %43}," + " %44," + " %45, %46," + " p, %48, %49;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + 
CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x160x64_F16E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x160x64 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x160x64_F32E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[80]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %84, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n160k64.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79}," + " %80," + " %81," + " %82, %83," + " p, %85, %86;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), 
"+f"(d77), "+f"(d78), "+f"(d79) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x160x64_F32E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x160x64 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x160x64_F32E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[80]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %87, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n160k64.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79}," + "{%80, %81, %82, %83}," + " %84," + " %85, %86," + " p, %88, %89;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), 
+ "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x160x64_F32E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x176x64 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x176x64_F16E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[44]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %48, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n176k64.f16.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43}," + " %44," + " %45," + " %46, %47," + " p, %49, %50;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x176x64_F16E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x176x64 TN 
F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x176x64_F16E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[44]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %51, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n176k64.f16.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43}," + "{%44, %45, %46, %47}," + " %48," + " %49, %50," + " p, %52, %53;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x176x64_F16E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x176x64 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x176x64_F32E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[88]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float 
& d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + float & d80, float & d81, float & d82, float & d83, + float & d84, float & d85, float & d86, float & d87, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %92, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n176k64.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87}," + " %88," + " %89," + " %90, %91," + " p, %93, %94;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), + "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83), + "+f"(d84), "+f"(d85), "+f"(d86), "+f"(d87) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x176x64_F32E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x176x64 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x176x64_F32E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using 
ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[88]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + float & d80, float & d81, float & d82, float & d83, + float & d84, float & d85, float & d86, float & d87, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %95, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n176k64.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87}," + "{%88, %89, %90, %91}," + " %92," + " %93, %94," + " p, %96, %97;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), + "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83), + "+f"(d84), "+f"(d85), "+f"(d86), "+f"(d87) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x176x64_F32E5M2E5M2_RS_TN without 
CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x192x64 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x192x64_F16E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[48]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %52, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n192k64.f16.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47}," + " %48," + " %49," + " %50, %51," + " p, %53, %54;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x192x64_F16E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x192x64 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x192x64_F16E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[48]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & 
d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %55, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n192k64.f16.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47}," + "{%48, %49, %50, %51}," + " %52," + " %53, %54," + " p, %56, %57;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x192x64_F16E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x192x64 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x192x64_F32E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[96]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, 
+ float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + float & d80, float & d81, float & d82, float & d83, + float & d84, float & d85, float & d86, float & d87, + float & d88, float & d89, float & d90, float & d91, + float & d92, float & d93, float & d94, float & d95, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %100, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n192k64.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95}," + " %96," + " %97," + " %98, %99," + " p, %101, %102;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), + "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83), + "+f"(d84), "+f"(d85), "+f"(d86), "+f"(d87), + "+f"(d88), "+f"(d89), "+f"(d90), "+f"(d91), + "+f"(d92), "+f"(d93), "+f"(d94), "+f"(d95) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x192x64_F32E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x192x64 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x192x64_F32E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[96]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & 
d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + float & d80, float & d81, float & d82, float & d83, + float & d84, float & d85, float & d86, float & d87, + float & d88, float & d89, float & d90, float & d91, + float & d92, float & d93, float & d94, float & d95, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %103, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n192k64.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95}," + "{%96, %97, %98, %99}," + " %100," + " %101, %102," + " p, %104, %105;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), + "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83), + "+f"(d84), "+f"(d85), "+f"(d86), "+f"(d87), + "+f"(d88), "+f"(d89), "+f"(d90), "+f"(d91), + "+f"(d92), "+f"(d93), "+f"(d94), "+f"(d95) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x192x64_F32E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); 
+#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x208x64 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x208x64_F16E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[52]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %56, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n208k64.f16.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51}," + " %52," + " %53," + " %54, %55," + " p, %57, %58;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x208x64_F16E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x208x64 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x208x64_F16E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = 
uint64_t[1]; + using CRegisters = uint32_t[52]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %59, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n208k64.f16.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51}," + "{%52, %53, %54, %55}," + " %56," + " %57, %58," + " p, %60, %61;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x208x64_F16E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x208x64 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x208x64_F32E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[104]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + 
float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %108, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n208k64.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103}," + " %104," + " %105," + " %106, %107," + " p, %109, %110;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), 
"n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x208x64_F32E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x208x64 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x208x64_F32E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[104]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %111, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n208k64.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103}," + "{%104, %105, %106, %107}," + " %108," + " %109, %110," + " p, %112, %113;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), 
+ "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x208x64_F32E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x224x64 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x224x64_F16E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[56]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %60, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n224k64.f16.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, 
%52, %53, %54, %55}," + " %56," + " %57," + " %58, %59," + " p, %61, %62;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x224x64_F16E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x224x64 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x224x64_F16E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[56]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %63, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n224k64.f16.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55}," + "{%56, %57, %58, %59}," + " %60," + " %61, %62," + " p, %64, %65;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + 
"+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x224x64_F16E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x224x64 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x224x64_F32E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[112]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %116, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n224k64.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, 
%50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111}," + " %112," + " %113," + " %114, %115," + " p, %117, %118;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x224x64_F32E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x224x64 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x224x64_F32E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[112]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + 
float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %119, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n224k64.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111}," + "{%112, %113, %114, %115}," + " %116," + " %117, %118," + " p, %120, %121;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use 
SM90::GMMA::SPARSE::GMMA_64x224x64_F32E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x240x64 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x240x64_F16E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[60]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %64, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n240k64.f16.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59}," + " %60," + " %61," + " %62, %63," + " p, %65, %66;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x240x64_F16E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if 
defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x240x64 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x240x64_F16E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[60]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %67, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n240k64.f16.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59}," + "{%60, %61, %62, %63}," + " %64," + " %65, %66," + " p, %68, %69;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x240x64_F16E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x240x64 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + 
GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x240x64_F32E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[120]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + float & d112, float & d113, float & d114, float & d115, + float & d116, float & d117, float & d118, float & d119, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %124, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n240k64.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119}," + " %120," + " %121," + " %122, %123," + " p, %125, %126;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), 
"+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), + "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115), + "+f"(d116), "+f"(d117), "+f"(d118), "+f"(d119) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x240x64_F32E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +// SPARSE GMMA 64x240x64 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x240x64_F32E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[120]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, 
float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + float & d112, float & d113, float & d114, float & d115, + float & d116, float & d117, float & d118, float & d119, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %127, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n240k64.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119}," + "{%120, %121, %122, %123}," + " %124," + " %125, %126," + " p, %128, %129;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), + "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115), + "+f"(d116), "+f"(d117), "+f"(d118), "+f"(d119) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x240x64_F32E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x256x64 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x256x64_F16E5M2E5M2_SS_TN +{ + using 
DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[64]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %68, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n256k64.f16.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63}," + " %64," + " %65," + " %66, %67," + " p, %69, %70;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x256x64_F16E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x256x64 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x256x64_F16E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[64]; + + CUTE_HOST_DEVICE static void 
+ fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %71, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n256k64.f16.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63}," + "{%64, %65, %66, %67}," + " %68," + " %69, %70," + " p, %72, %73;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x256x64_F16E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x256x64 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x256x64_F32E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[128]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d000, float & d001, 
float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + float & d112, float & d113, float & d114, float & d115, + float & d116, float & d117, float & d118, float & d119, + float & d120, float & d121, float & d122, float & d123, + float & d124, float & d125, float & d126, float & d127, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %132, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n256k64.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119, " + " %120, %121, %122, %123, %124, %125, %126, %127}," + " %128," + " %129," + " %130, %131," + " p, %133, %134;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), 
"+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), + "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115), + "+f"(d116), "+f"(d117), "+f"(d118), "+f"(d119), + "+f"(d120), "+f"(d121), "+f"(d122), "+f"(d123), + "+f"(d124), "+f"(d125), "+f"(d126), "+f"(d127) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x256x64_F32E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x256x64 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x256x64_F32E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[128]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + float & 
d112, float & d113, float & d114, float & d115, + float & d116, float & d117, float & d118, float & d119, + float & d120, float & d121, float & d122, float & d123, + float & d124, float & d125, float & d126, float & d127, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %135, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n256k64.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119, " + " %120, %121, %122, %123, %124, %125, %126, %127}," + "{%128, %129, %130, %131}," + " %132," + " %133, %134," + " p, %136, %137;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), + "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115), + "+f"(d116), "+f"(d117), "+f"(d118), "+f"(d119), + "+f"(d120), "+f"(d121), "+f"(d122), "+f"(d123), + "+f"(d124), "+f"(d125), "+f"(d126), "+f"(d127) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x256x64_F32E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace SM90::GMMA::SPARSE + + 
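All of the sparse GMMA wrappers above follow one register-footprint convention: in the SS variants A and B arrive as 64-bit shared-memory descriptors (the RS variants take A as a uint32_t[4] register fragment instead), the sparsity metadata E is a single 32-bit register, and the accumulator count in CRegisters scales with N (packed f16x2 words for F16 accumulation, individual floats for F32). A minimal compile-time sketch of that convention follows; the include path is an assumption, and everything else uses the structs exactly as declared above.

// Sketch only (not part of this patch): spot-check the register-footprint
// convention of one unguarded 64x256x64 wrapper. The header path is assumed.
#include <cstdint>
#include <cute/arch/mma_sm90_gmma_sparse.hpp>

using SpOp = cute::SM90::GMMA::SPARSE::GMMA_64x256x64_F16E5M2E5M2_SS_TN<>;

static_assert(sizeof(SpOp::ARegisters) == sizeof(std::uint64_t[1]),
              "A is one 64-bit smem descriptor in the SS variant");
static_assert(sizeof(SpOp::ERegisters) == sizeof(std::uint32_t[1]),
              "sparsity metadata E is a single 32-bit register");
static_assert(sizeof(SpOp::CRegisters) == sizeof(std::uint32_t[64]),
              "N=256 with F16 accumulation packs 64 f16x2 accumulator words");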
+//////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace cute diff --git a/include/cute/arch/util.hpp b/include/cute/arch/util.hpp index 61417d8360..3749a9c255 100644 --- a/include/cute/arch/util.hpp +++ b/include/cute/arch/util.hpp @@ -31,7 +31,6 @@ #pragma once #include - #include #if defined(__clang__) && defined(__CUDA__) @@ -254,6 +253,28 @@ explode(Fn fn, return fn(d[Id]..., a[Ia]..., b[Ib]..., c[Ic]..., e[Ie]..., f[If]...); } +template +CUTE_HOST_DEVICE constexpr +void +explode(Fn fn, + PtrD&& d, int_sequence, + PtrA&& a, int_sequence, + PtrB&& b, int_sequence, + PtrC&& c, int_sequence, + PtrE&& e, int_sequence, + PtrF&& f, int_sequence, + PtrG&& g, int_sequence) +{ + return fn(d[Id]..., a[Ia]..., b[Ib]..., c[Ic]..., e[Ie]..., f[If]..., g[Ig]...); +} + // // Utility for exploding tuples into functions // diff --git a/include/cute/atom/copy_atom.hpp b/include/cute/atom/copy_atom.hpp index 20a0627627..dd6b4e52a0 100644 --- a/include/cute/atom/copy_atom.hpp +++ b/include/cute/atom/copy_atom.hpp @@ -30,16 +30,13 @@ **************************************************************************************************/ #pragma once -#include - -#include - -#include -#include - -#include - -#include +#include // CUTE_HOST_DEVICE +#include // cute::Tensor +#include // cute::__CUTE_REQUIRES +#include // cute::is_tuple +#include // cute::is_constant, cute::is_integral +#include // cute::Copy_Traits +#include // cute::TiledMMA namespace cute { @@ -651,10 +648,12 @@ print(ThrCopy const& thr_copy) print(TiledCopy{}); } -template +// TiledCopy to LaTeX TikZ +template CUTE_HOST_DEVICE auto -print_latex(TiledCopy const& copy) +print_latex(TiledCopy const& copy, + TikzColorFn color = {}) // lambda(thr_idx,val_idx) -> tikz color string { auto [layoutS_MN, thrID_S] = copy.get_layoutS_MN(); auto [layoutD_MN, thrID_D] = copy.get_layoutD_MN(); @@ -663,13 +662,15 @@ print_latex(TiledCopy const& copy) layoutD_MN, thrID_D); } -// MNK Copy Layout to Latex TIKZ -- 8-value color coded by thread +// MNK Copy Layout to LaTeX TikZ template + class LayoutD, class ThrIDD, + class TikzColorFn = TikzColor_TV> CUTE_HOST_DEVICE void print_latex_copy(LayoutS const& S, ThrIDS const& TS, // (m,n) -> (tid,vid) and tid -> thr_idx - LayoutD const& D, ThrIDD const& TD) // (m,n) -> (tid,vid) and tid -> thr_idx + LayoutD const& D, ThrIDD const& TD, // (m,n) -> (tid,vid) and tid -> thr_idx + TikzColorFn color = {}) // lambda(thr_idx,val_idx) -> tikz color string { CUTE_STATIC_ASSERT_V(rank(S) == Int<2>{}); CUTE_STATIC_ASSERT_V(rank(D) == Int<2>{}); @@ -677,33 +678,17 @@ print_latex_copy(LayoutS const& S, ThrIDS const& TS, // (m,n) -> (tid,vid) and assert(size<0>(S) == size<0>(D)); assert(size<1>(S) == size<1>(D)); - char const* latex_header = - "\\documentclass{standalone}\n" - "\\usepackage{tikz}\n" - "\\usetikzlibrary{external}\n" - "\\tikzexternalize\n" - "\\begin{document}\n" - "\\begin{tikzpicture}[x={(0cm,-1cm)},y={(1cm,0cm)},box/.style={rectangle,draw=black,thick,minimum size=1cm,anchor=center}]\n\n"; - char const* latex_footer = - "\\end{tikzpicture}\n" - "\\end{document}\n"; - - char const* color_map[8] = {"{rgb,255:red,175;green,175;blue,255}", - "{rgb,255:red,175;green,255;blue,175}", - "{rgb,255:red,255;green,255;blue,175}", - "{rgb,255:red,255;green,175;blue,175}", - "{rgb,255:red,210;green,210;blue,255}", - "{rgb,255:red,210;green,255;blue,210}", - "{rgb,255:red,255;green,255;blue,210}", - "{rgb,255:red,255;green,210;blue,210}",}; - - // Header + // 
Commented prints printf("%% LayoutS: "); print(S); printf("\n"); printf("%% ThrIDS : "); print(TS); printf("\n"); printf("%% LayoutD: "); print(D); printf("\n"); printf("%% ThrIDD : "); print(TD); printf("\n\n"); - printf(latex_header); + // Header + printf("\\documentclass[convert]{standalone}\n" + "\\usepackage{tikz}\n\n" + "\\begin{document}\n" + "\\begin{tikzpicture}[x={(0cm,-1cm)},y={(1cm,0cm)},every node/.style={minimum size=1cm, outer sep=0pt}]\n\n"); // S starting at 0,0 for (int i = 0; i < size<0>(S); ++i) { @@ -712,12 +697,22 @@ print_latex_copy(LayoutS const& S, ThrIDS const& TS, // (m,n) -> (tid,vid) and int val_idx = S(i,j) / size(TS); int thr_idx = TS(thrid); - printf("\\node[box,fill=%s] at (%d,%d) {\\shortstack{T%d \\\\ V%d}};\n", - color_map[thr_idx % 8], + printf("\\node[fill=%s] at (%d,%d) {\\shortstack{T%d \\\\ V%d}};\n", + color(thr_idx, val_idx), i, j, thr_idx, val_idx); } } + // Grid + printf("\\draw[color=black,thick,shift={(-0.5,-0.5)}] (%d,%d) grid (%d,%d);\n\n", + 0, 0, int(size<0>(S)), int(size<1>(S))); + // S Labels + for (int i = 0, j = -1; i < size<0>(S); ++i) { + printf("\\node at (%d,%d) {\\Large{\\texttt{%d}}};\n", i, j, i); + } + for (int i = -1, j = 0; j < size<1>(S); ++j) { + printf("\\node at (%d,%d) {\\Large{\\texttt{%d}}};\n", i, j, j); + } // D starting at 0,size<1>(S)+3 for (int i = 0; i < size<0>(D); ++i) { @@ -726,30 +721,26 @@ print_latex_copy(LayoutS const& S, ThrIDS const& TS, // (m,n) -> (tid,vid) and int val_idx = D(i,j) / size(TD); int thr_idx = TD(thrid); - printf("\\node[box,fill=%s] at (%d,%d) {\\shortstack{T%d \\\\ V%d}};\n", - color_map[thr_idx % 8], + printf("\\node[fill=%s] at (%d,%d) {\\shortstack{T%d \\\\ V%d}};\n", + color(thr_idx, val_idx), i, j + size<1>(S) + 3, thr_idx, val_idx); } } - - // S Labels - for (int i = 0, j = -1; i < size<0>(S); ++i) { - printf("\\node at (%d,%d) {\\Large{\\texttt{%d}}};\n", i, j, i); - } - for (int j = 0, i = -1; j < size<1>(S); ++j) { - printf("\\node at (%d,%d) {\\Large{\\texttt{%d}}};\n", i, j, j); - } + // Grid + printf("\\draw[color=black,thick,shift={(-0.5,-0.5)}] (%d,%d) grid (%d,%d);\n\n", + 0, int(size<1>(S)+3), int(size<0>(D)), int(size<1>(D)+size<1>(S)+3)); // D Labels - for (int i = 0, j = size<1>(D); i < size<0>(S); ++i) { + for (int i = 0, j = size<1>(D); i < size<0>(D); ++i) { printf("\\node at (%d,%d) {\\Large{\\texttt{%d}}};\n", i, j + size<1>(S) + 3, i); } - for (int j = 0, i = -1; j < size<1>(D); ++j) { + for (int i = -1, j = 0; j < size<1>(D); ++j) { printf("\\node at (%d,%d) {\\Large{\\texttt{%d}}};\n", i, j + size<1>(S) + 3, j); } // Footer - printf(latex_footer); + printf("\\end{tikzpicture}\n" + "\\end{document}\n"); } } // end namespace cute diff --git a/include/cute/atom/copy_traits_sm50.hpp b/include/cute/atom/copy_traits_sm50.hpp index 8be0ef7bba..7a693805e6 100644 --- a/include/cute/atom/copy_traits_sm50.hpp +++ b/include/cute/atom/copy_traits_sm50.hpp @@ -39,7 +39,7 @@ namespace cute { template <> -struct Copy_Traits +struct Copy_Traits { // Logical thread id to thread idx (one-thread) using ThrID = Layout<_32>; @@ -55,4 +55,21 @@ struct Copy_Traits using RefLayout = SrcLayout; }; +template <> +struct Copy_Traits +{ + // Logical thread id to thread idx (one-thread) + using ThrID = Layout<_32>; + + // Map from (src-thr,src-val) to bit + using SrcLayout = Layout, + Stride<_64, _1>>; + // Map from (dst-thr,dst-val) to bit + using DstLayout = Layout, Shape<_32, _2>>, + Stride,Stride< _1, _256>>>; + + // Reference map from (thr,val) to bit + using RefLayout = SrcLayout; +}; 
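The reworked LaTeX printers in copy_atom.hpp now take an optional functor mapping (thr_idx, val_idx) to a TikZ color string, defaulting to TikzColor_TV. A short usage sketch follows; the wrapper function, the lambda, and the color strings are illustrative only, while print_latex_copy, get_layoutS_MN, and get_layoutD_MN are used with the signatures shown in this patch.

// Sketch only (not part of this patch): render a TiledCopy to TikZ, coloring
// cells by value index rather than by thread index.
#include <cute/atom/copy_atom.hpp>

template <class TiledCopy>
void dump_copy_tikz(TiledCopy const& copy) {
  // Must return a TikZ color expression usable inside \node[fill=...].
  auto by_val = [](int /*thr_idx*/, int val_idx) {
    return (val_idx % 2 == 0) ? "blue!20" : "red!20";
  };
  auto [layoutS_MN, thrID_S] = copy.get_layoutS_MN();
  auto [layoutD_MN, thrID_D] = copy.get_layoutD_MN();
  cute::print_latex_copy(layoutS_MN, thrID_S,
                         layoutD_MN, thrID_D,
                         by_val);   // trailing color functor added by this change
}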
+ } // end namespace cute diff --git a/include/cute/atom/copy_traits_sm90_im2col.hpp b/include/cute/atom/copy_traits_sm90_im2col.hpp index ad4f8675b5..54f76073b1 100644 --- a/include/cute/atom/copy_traits_sm90_im2col.hpp +++ b/include/cute/atom/copy_traits_sm90_im2col.hpp @@ -450,7 +450,9 @@ make_im2col_tma_copy_desc( CUtensorMapInterleave tma_interleave = CU_TENSOR_MAP_INTERLEAVE_NONE; CUtensorMapL2promotion tma_l2Promotion = to_CUtensorMapL2promotion(aux_params.l2promo_); CUtensorMapFloatOOBfill tma_oob_fill = to_CUtensorMapFloatOOBfill(aux_params.oobfill_); - CUtensorMapSwizzle tma_swizzle = TMA::to_CUtensorMapSwizzle(detail::get_tma_swizzle_bits(smem_swizzle)); + TMA::SmemSwizzleBits swizzle_bits = detail::get_tma_swizzle_bits(smem_swizzle); + TMA::SmemSwizzleBase swizzle_base = detail::get_tma_swizzle_base(smem_swizzle); + CUtensorMapSwizzle tma_swizzle = TMA::to_CUtensorMapSwizzle(swizzle_bits, swizzle_base); CUresult encode_result = CUTLASS_CUDA_DRIVER_WRAPPER_CALL(cuTensorMapEncodeIm2col)( &tma_desc, @@ -636,11 +638,11 @@ make_tma_atom_im2col(CopyOp, auto range_c = size<0,0>(tma_layout_vt); auto range_whdn = size<0,1>(tma_layout_vt); - Tensor gtensor_cwhdn = make_tensor(gtensor.data(), - flatten(make_layout(basis_get(stride<0,0>(tma_layout_vt), gtensor.layout()), - basis_get(stride<0,1>(tma_layout_vt), gtensor.layout())))); - + flatten(make_layout(make_layout(basis_get(stride<0,0>(tma_layout_vt), gtensor.shape()), + basis_get(stride<0,0>(tma_layout_vt), gtensor.stride())), + make_layout(basis_get(stride<0,1>(tma_layout_vt), gtensor.shape()), + basis_get(stride<0,1>(tma_layout_vt), gtensor.stride()))))); auto [tma_desc, tma_tensor] = make_im2col_tma_copy_desc( gtensor_cwhdn, range_c, diff --git a/include/cute/atom/copy_traits_sm90_tma.hpp b/include/cute/atom/copy_traits_sm90_tma.hpp index 2238c41897..3738cc3962 100644 --- a/include/cute/atom/copy_traits_sm90_tma.hpp +++ b/include/cute/atom/copy_traits_sm90_tma.hpp @@ -41,6 +41,7 @@ #include #include + #include namespace cute @@ -241,15 +242,22 @@ struct Copy_Traits // Construct an executable SM90_TMA_LOAD_MULTICAST with tma_mbar CUTE_HOST_DEVICE constexpr Copy_Traits - with(uint64_t& tma_load_mbar, uint16_t const& multicast_mask) const { - return {{}, {&tma_desc_, &tma_load_mbar, multicast_mask}}; + with( + uint64_t& tma_load_mbar, + uint16_t const& multicast_mask, + TMA::CacheHintSm90 const& cache_hint = TMA::CacheHintSm90::EVICT_NORMAL) const { + return {{}, {&tma_desc_, &tma_load_mbar, multicast_mask, static_cast(cache_hint)}}; } // Construct an executable SM90_TMA_LOAD_MULTICAST_OP with tma_mbar (temp. 
overloaded for grouped gemm/ptr array gemm) CUTE_HOST_DEVICE constexpr Copy_Traits - with(TmaDescriptor const* new_tma_desc, uint64_t& tma_load_mbar, uint16_t const& multicast_mask) const { - return {{}, {new_tma_desc, &tma_load_mbar, multicast_mask}}; + with( + TmaDescriptor const* new_tma_desc, + uint64_t& tma_load_mbar, + uint16_t const& multicast_mask, + TMA::CacheHintSm90 const& cache_hint = TMA::CacheHintSm90::EVICT_NORMAL) const { + return {{}, {new_tma_desc, &tma_load_mbar, multicast_mask, static_cast(cache_hint)}}; } // Generate the TMA coord tensor @@ -287,7 +295,8 @@ struct Copy_Traits tuple< TmaDescriptor const*, uint64_t*, // smem mbarrier - uint16_t // multicast mask + uint16_t, // multicast mask + uint64_t // cache hint > const opargs_; }; @@ -684,8 +693,10 @@ construct_tma_gbasis(Tensor const& gtensor, // The origin // TMA parameter checking // - CUTE_STATIC_ASSERT_V(product_each(shape(slayout)) == product_each(shape(cta_v_map)), - "TMA requires CTA_Tile and SLayout top-level shape equivalence."); + // CUTE_STATIC_ASSERT_V(product_each(shape(slayout)) == product_each(shape(cta_v_map)), + // "TMA requires CTA_Tile and SLayout top-level shape equivalence."); + CUTE_STATIC_ASSERT_V(size(slayout) == size(cta_v_map), + "TMA requires CTA_Tile and SLayout top-level size equivalence."); #if 0 print("gtensor : "); print(gtensor); print("\n"); @@ -983,7 +994,9 @@ make_tma_copy_desc(Tensor const& gtensor, // The origin CUtensorMapFloatOOBfill tma_oobFill = CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE; // TMA smem swizzle type - CUtensorMapSwizzle smem_swizzle = TMA::to_CUtensorMapSwizzle(get_tma_swizzle_bits(swizzle)); + TMA::SmemSwizzleBits swizzle_bits = get_tma_swizzle_bits(swizzle); + TMA::SmemSwizzleBase swizzle_base = get_tma_swizzle_base(swizzle); + CUtensorMapSwizzle smem_swizzle = TMA::to_CUtensorMapSwizzle(swizzle_bits, swizzle_base); CUresult result = CUTLASS_CUDA_DRIVER_WRAPPER_CALL(cuTensorMapEncodeTiled)( &tma_desc, tma_format, diff --git a/include/cute/atom/copy_traits_sm90_tma_swizzle.hpp b/include/cute/atom/copy_traits_sm90_tma_swizzle.hpp index bb44a8353d..3286e72b36 100644 --- a/include/cute/atom/copy_traits_sm90_tma_swizzle.hpp +++ b/include/cute/atom/copy_traits_sm90_tma_swizzle.hpp @@ -68,4 +68,26 @@ get_tma_swizzle_bits(Layout const& layout) return get_tma_swizzle_bits(get_swizzle_portion(layout)); } +template +CUTE_HOST_DEVICE constexpr +TMA::SmemSwizzleBase +get_tma_swizzle_base(Swizzle) +{ + if constexpr (M == 4) { + static_assert(0 <= B && B <= 3, "Expected B = 0,1,2, or 3 when M == 4. Unsupported layout swizzle."); + static_assert(S == 3, "Expected S = 3 when M == 4. 
Unsupported layout swizzle."); + return TMA::SmemSwizzleBase::SWIZZLE_BASE_16B; + } + else { + static_assert(M == 4, "Expected 128b=16B=(2^4)B base swizzle."); + } +} + +template +TMA::SmemSwizzleBase +get_tma_swizzle_base(Layout const& layout) +{ + return get_tma_swizzle_base(get_swizzle_portion(layout)); +} + } // namespace cute::detail diff --git a/include/cute/atom/mma_atom.hpp b/include/cute/atom/mma_atom.hpp index 2358dd568f..bf40827436 100644 --- a/include/cute/atom/mma_atom.hpp +++ b/include/cute/atom/mma_atom.hpp @@ -45,11 +45,12 @@ template struct MMA_Atom : MMA_Atom> {}; -template -struct MMA_Atom> - : MMA_Traits +template +struct MMA_Atom> + : MMA_Traits { - using Traits = MMA_Traits; + using MMA_Op = MMAOperation; + using Traits = MMA_Traits; // Element value types from the MMA_Traits using ValTypeD = typename Traits::ValTypeD; @@ -331,7 +332,7 @@ struct TiledMMA : MMA_Atom make_layout(size<2>(AtomShape_MNK{}))); auto b_tensor = zipped_divide(t_tensor, b_tile); // ((AtomN,AtomK),(RestN,RestK)) - // Transform the Atom mode from (N,K) to (Thr,Val) + // Transform the Atom mode from (M,K) to (Thr,Val) auto tv_tensor = b_tensor.compose(AtomLayoutB_TV{},_); // ((ThrV,FrgV),(RestN,RestK)) // Tile the tensor for the Thread @@ -733,18 +734,22 @@ print(ThrMMA const& thr_mma) print(static_cast(thr_mma)); } -template +// MMA Atom to LaTeX TikZ +template CUTE_HOST_DEVICE void -print_latex(MMA_Atom const& mma_atom) +print_latex(MMA_Atom const& mma_atom, + TikzColorFn color = {}) // lambda(thr_idx,val_idx) -> tikz color string { print_latex(make_tiled_mma(mma_atom)); } -template +// TiledMMA to LaTeX TikZ +template CUTE_HOST_DEVICE void -print_latex(TiledMMA const& mma) +print_latex(TiledMMA const& mma, + TikzColorFn color = {}) // lambda(thr_idx,val_idx) -> tikz color string { auto layout_and_thrid_C = mma.get_layoutC_MN(); auto layoutC_MN = get<0>(layout_and_thrid_C); @@ -763,71 +768,17 @@ print_latex(TiledMMA const& mma) layoutB_NK, thrID_B); } -// MNK MMA Layout to console printer -template -CUTE_HOST_DEVICE -void -print_layout_mma(LayoutC const& C, ThrIDC const& TC, // (m,n) -> (tid,vid) and tid -> thr_idx - LayoutA const& A, ThrIDA const& TA, // (m,k) -> (tid,vid) and tid -> thr_idx - LayoutB const& B, ThrIDB const& TB) // (n,k) -> (tid,vid) and tid -> thr_idx -{ - CUTE_STATIC_ASSERT_V(rank(C) == Int<2>{}); - CUTE_STATIC_ASSERT_V(rank(A) == Int<2>{}); - CUTE_STATIC_ASSERT_V(rank(B) == Int<2>{}); - - assert(size<0>(A) == size<0>(C)); - assert(size<0>(B) == size<1>(C)); - assert(size<1>(A) == size<1>(B)); - - int a_width = size<1>(A) * 6 + 4; - - // Print out B (white-shifted) k-by-n - for (int k = 0; k < size<1>(B); ++k) { - // Header - printf("%*s", a_width, ""); - for (int n = 0; n < size<0>(B); ++n) printf("+-----"); - printf("+\n"); - // Values - printf("%*s", a_width, ""); - for (int n = 0; n < size<0>(B); ++n) printf("|T%02dV%1d", int(TB(B(n,k) % size(TB))), int(B(n,k) / size(TB))); - printf("|\n"); - } - // Footer - printf("%*s", a_width, ""); - for (int n = 0; n < size<0>(B); ++n) printf("+-----"); - printf("+\n\n"); - - // Print out A m-by-k and C m-by-n - for (int m = 0; m < size<0>(A); ++m) { - // Header - for (int k = 0; k < size<1>(A); ++k) printf("+-----"); - printf("+ "); - for (int n = 0; n < size<1>(C); ++n) printf("+-----"); - printf("+\n"); - // Values - for (int k = 0; k < size<1>(A); ++k) printf("|T%02dV%1d", int(TA(A(m,k) % size(TA))), int(A(m,k) / size(TA))); - printf("| "); - for (int n = 0; n < size<1>(C); ++n) printf("|T%02dV%1d", int(TC(C(m,n) % size(TC))), 
int(C(m,n) / size(TC))); - printf("|\n"); - } - // Footer - for (int k = 0; k < size<1>(A); ++k) printf("+-----"); - printf("+ "); - for (int n = 0; n < size<1>(C); ++n) printf("+-----"); - printf("+\n"); -} - -// MNK MMA Layout to Latex TIKZ -- 8-value color coded by thread +// MNK MMA Layout to LaTeX TikZ template + class LayoutB, class ThrIDB, + class TikzColorFn = TikzColor_TV> CUTE_HOST_DEVICE void print_latex_mma(LayoutC const& C, ThrIDC const& TC, // (m,n) -> (tid,vid) and tid -> thr_idx LayoutA const& A, ThrIDA const& TA, // (m,k) -> (tid,vid) and tid -> thr_idx - LayoutB const& B, ThrIDB const& TB) // (n,k) -> (tid,vid) and tid -> thr_idx + LayoutB const& B, ThrIDB const& TB, // (n,k) -> (tid,vid) and tid -> thr_idx + TikzColorFn color = {}) // lambda(thr_idx,val_idx) -> tikz color string { CUTE_STATIC_ASSERT_V(rank(C) == Int<2>{}); CUTE_STATIC_ASSERT_V(rank(A) == Int<2>{}); @@ -837,35 +788,18 @@ print_latex_mma(LayoutC const& C, ThrIDC const& TC, // (m,n) -> (tid,vid) and assert(size<0>(B) == size<1>(C)); assert(size<1>(A) == size<1>(B)); - char const* latex_header = - "\\documentclass{standalone}\n" - "\\usepackage{tikz}\n" - "\\usetikzlibrary{external}\n" - "\\tikzexternalize\n" - "\\begin{document}\n" - "\\begin{tikzpicture}[x={(0cm,-1cm)},y={(1cm,0cm)},box/.style={rectangle,draw=black,thick,minimum size=1cm,anchor=center}]\n\n"; - char const* latex_footer = - "\\end{tikzpicture}\n" - "\\end{document}\n"; - - char const* color_map[8] = {"{rgb,255:red,175;green,175;blue,255}", - "{rgb,255:red,175;green,255;blue,175}", - "{rgb,255:red,255;green,255;blue,175}", - "{rgb,255:red,255;green,175;blue,175}", - "{rgb,255:red,210;green,210;blue,255}", - "{rgb,255:red,210;green,255;blue,210}", - "{rgb,255:red,255;green,255;blue,210}", - "{rgb,255:red,255;green,210;blue,210}"}; - - // Header + // Commented prints printf("%% LayoutC: "); print(C); printf("\n"); printf("%% ThrIDC : "); print(TC); printf("\n"); printf("%% LayoutA: "); print(A); printf("\n"); printf("%% ThrIDA : "); print(TA); printf("\n"); printf("%% LayoutB: "); print(B); printf("\n"); printf("%% ThrIDB : "); print(TB); printf("\n\n"); - - printf(latex_header); + // Header + printf("\\documentclass[convert]{standalone}\n" + "\\usepackage{tikz}\n\n" + "\\begin{document}\n" + "\\begin{tikzpicture}[x={(0cm,-1cm)},y={(1cm,0cm)},every node/.style={minimum size=1cm, outer sep=0pt}]\n\n"); // C starting at 0,0 for (int m = 0; m < size<0>(C); ++m) { @@ -874,12 +808,15 @@ print_latex_mma(LayoutC const& C, ThrIDC const& TC, // (m,n) -> (tid,vid) and int val_idx = C(m,n) / size(TC); int thr_idx = TC(thrid); - printf("\\node[box,fill=%s] at (%d,%d) {\\shortstack{T%d \\\\ V%d}};\n", - color_map[thr_idx % 8], + printf("\\node[fill=%s] at (%d,%d) {\\shortstack{T%d \\\\ V%d}};\n", + color(thr_idx, val_idx), m, n, thr_idx, val_idx); } } + // Grid + printf("\\draw[color=black,thick,shift={(-0.5,-0.5)}] (%d,%d) grid (%d,%d);\n\n", + 0, 0, int(size<0>(C)), int(size<1>(C))); // A starting at 0,-size<1>(A)-1 for (int m = 0; m < size<0>(A); ++m) { @@ -888,12 +825,22 @@ print_latex_mma(LayoutC const& C, ThrIDC const& TC, // (m,n) -> (tid,vid) and int val_idx = A(m,k) / size(TA); int thr_idx = TA(thrid); - printf("\\node[box,fill=%s] at (%d,%d) {\\shortstack{T%d \\\\ V%d}};\n", - color_map[thr_idx % 8], + printf("\\node[fill=%s] at (%d,%d) {\\shortstack{T%d \\\\ V%d}};\n", + color(thr_idx, val_idx), m, k-1-size<1>(A), thr_idx, val_idx); } } + // Grid + printf("\\draw[color=black,thick,shift={(-0.5,-0.5)}] (%d,%d) grid (%d,%d);\n\n", + 0, 
int(-size<1>(A)-1), int(size<0>(A)), -1); + // A labels + for (int m = 0, k = -1; m < size<0>(A); ++m) { + printf("\\node at (%d,%d) {\\Large{\\texttt{%d}}};\n", m, k-1-size<1>(A), m); + } + for (int m = -1, k = 0; k < size<1>(A); ++k) { + printf("\\node at (%d,%d) {\\Large{\\texttt{%d}}};\n", m, k-1-size<1>(A), k); + } // B starting at -size<1>(B)-1,0 for (int n = 0; n < size<0>(B); ++n) { @@ -902,30 +849,82 @@ print_latex_mma(LayoutC const& C, ThrIDC const& TC, // (m,n) -> (tid,vid) and int val_idx = B(n,k) / size(TB); int thr_idx = TB(thrid); - printf("\\node[box,fill=%s] at (%d,%d) {\\shortstack{T%d \\\\ V%d}};\n", - color_map[thr_idx % 8], + printf("\\node[fill=%s] at (%d,%d) {\\shortstack{T%d \\\\ V%d}};\n", + color(thr_idx, val_idx), k-1-size<1>(B), n, thr_idx, val_idx); } } - - // A labels - for (int m = 0, k = -1; m < size<0>(A); ++m) { - printf("\\node at (%d,%d) {\\Large{\\texttt{%d}}};\n", m, k-1-size<1>(A), m); - } - for (int k = 0, m = -1; k < size<1>(A); ++k) { - printf("\\node at (%d,%d) {\\Large{\\texttt{%d}}};\n", m, k-1-size<1>(A), k); - } + // Grid + printf("\\draw[color=black,thick,shift={(-0.5,-0.5)}] (%d,%d) grid (%d,%d);\n\n", + int(-size<1>(B)-1), 0, -1, int(size<0>(B))); // B labels - for (int n = 0, k = -1; n < size<0>(B); ++n) { + for (int n = 0, k = -1; n < size<0>(B); ++n) { printf("\\node at (%d,%d) {\\Large{\\texttt{%d}}};\n", k-1-size<1>(B), n, n); } - for (int k = 0, n = -1; k < size<1>(B); ++k) { + for (int n = -1, k = 0; k < size<1>(B); ++k) { printf("\\node at (%d,%d) {\\Large{\\texttt{%d}}};\n", k-1-size<1>(B), n, k); } // Footer - printf(latex_footer); + printf("\\end{tikzpicture}\n" + "\\end{document}\n"); +} + +// MNK MMA Layout to console printer +template +CUTE_HOST_DEVICE +void +print_layout_mma(LayoutC const& C, ThrIDC const& TC, // (m,n) -> (tid,vid) and tid -> thr_idx + LayoutA const& A, ThrIDA const& TA, // (m,k) -> (tid,vid) and tid -> thr_idx + LayoutB const& B, ThrIDB const& TB) // (n,k) -> (tid,vid) and tid -> thr_idx +{ + CUTE_STATIC_ASSERT_V(rank(C) == Int<2>{}); + CUTE_STATIC_ASSERT_V(rank(A) == Int<2>{}); + CUTE_STATIC_ASSERT_V(rank(B) == Int<2>{}); + + assert(size<0>(A) == size<0>(C)); + assert(size<0>(B) == size<1>(C)); + assert(size<1>(A) == size<1>(B)); + + int a_width = size<1>(A) * 6 + 4; + + // Print out B (white-shifted) k-by-n + for (int k = 0; k < size<1>(B); ++k) { + // Header + printf("%*s", a_width, ""); + for (int n = 0; n < size<0>(B); ++n) printf("+-----"); + printf("+\n"); + // Values + printf("%*s", a_width, ""); + for (int n = 0; n < size<0>(B); ++n) printf("|T%02dV%1d", int(TB(B(n,k) % size(TB))), int(B(n,k) / size(TB))); + printf("|\n"); + } + // Footer + printf("%*s", a_width, ""); + for (int n = 0; n < size<0>(B); ++n) printf("+-----"); + printf("+\n\n"); + + // Print out A m-by-k and C m-by-n + for (int m = 0; m < size<0>(A); ++m) { + // Header + for (int k = 0; k < size<1>(A); ++k) printf("+-----"); + printf("+ "); + for (int n = 0; n < size<1>(C); ++n) printf("+-----"); + printf("+\n"); + // Values + for (int k = 0; k < size<1>(A); ++k) printf("|T%02dV%1d", int(TA(A(m,k) % size(TA))), int(A(m,k) / size(TA))); + printf("| "); + for (int n = 0; n < size<1>(C); ++n) printf("|T%02dV%1d", int(TC(C(m,n) % size(TC))), int(C(m,n) / size(TC))); + printf("|\n"); + } + // Footer + for (int k = 0; k < size<1>(A); ++k) printf("+-----"); + printf("+ "); + for (int n = 0; n < size<1>(C); ++n) printf("+-----"); + printf("+\n"); } // MNK MMA Layout to SVG -- 8-value color coded by thread diff --git 
a/include/cute/atom/mma_traits.hpp b/include/cute/atom/mma_traits.hpp index 8b9ac73642..0994698a87 100644 --- a/include/cute/atom/mma_traits.hpp +++ b/include/cute/atom/mma_traits.hpp @@ -30,23 +30,14 @@ **************************************************************************************************/ #pragma once -#include - -#include +#include // cute::Tensor +#include // cute::is_rmem +#include // cute::UniversalFMA +#include // cute::detail::explode namespace cute { -namespace detail { - -template -struct supports_output_scaling { static constexpr bool value = false; }; - -template -struct supports_output_scaling().accumulate_)>> { static constexpr bool value = true; }; - -} // end namespace detail - /** * concept MMA_Traits * { @@ -99,17 +90,27 @@ struct MMA_Traits> using CLayout = Layout>; }; +// Extract an MMA_Op from an MMA_Traits +template +struct MMA_Op {}; + +template +struct MMA_Op> { + using type = MMA_Op_Arg; +}; + // // Generic mma_unpack for any MMA_Traits // -template CUTE_HOST_DEVICE constexpr void -mma_unpack(MMA_Traits const& traits, +mma_unpack(AnyMMATraits const& traits, Tensor & D, Tensor const& A, Tensor const& B, @@ -121,87 +122,47 @@ mma_unpack(MMA_Traits const& traits, static_assert(is_rmem::value, "Expected registers in MMA_Atom::call"); // Register value types from the MMA_Operation register arrays + using MMA_Op = typename MMA_Op::type; using RegTypeD = typename remove_extent::type; using RegTypeA = typename remove_extent::type; using RegTypeB = typename remove_extent::type; using RegTypeC = typename remove_extent::type; - using MMATraits = MMA_Traits; - [[maybe_unused]] constexpr int RegNumD = extent::value; + Tensor rA = recast(A); + Tensor rB = recast(B); + Tensor rD = recast(D); + Tensor rC = recast(C); + + constexpr int RegNumD = extent::value; constexpr int RegNumA = extent::value; constexpr int RegNumB = extent::value; constexpr int RegNumC = extent::value; - Tensor rA = recast(A); - Tensor rB = recast(B); - CUTE_STATIC_ASSERT_V(size(rA) == Int{}); CUTE_STATIC_ASSERT_V(size(rB) == Int{}); - - if constexpr (is_same::value) - { - static_assert(is_same::value, "GMMA C and D value_type must match."); - static_assert(is_same::value, "GMMA C and D layouts must match."); - // assert((void*)&C == (void*)&D); - - Tensor rC = recast(D); // NOTE: D and C are same, so use mutable D - - //CUTE_STATIC_ASSERT_V(size(rC) == Int{}); - - if constexpr (detail::supports_output_scaling::value) { - detail::explode(MMA_Op::fma, - rA, make_int_sequence{}, - rB, make_int_sequence{}, - rC, make_int_sequence{}, - &(traits.accumulate_), seq<0>{}); - } - else { - detail::explode(MMA_Op::fma, - rA, make_int_sequence{}, - rB, make_int_sequence{}, - rC, make_int_sequence{}); - } - } - else { - Tensor rD = recast(D); - Tensor rC = recast(C); - - CUTE_STATIC_ASSERT_V(size(rD) == Int{}); - CUTE_STATIC_ASSERT_V(size(rC) == Int{}); - if constexpr (detail::supports_output_scaling::value) { - detail::explode(MMA_Op::fma, - rD, make_int_sequence{}, - rA, make_int_sequence{}, - rB, make_int_sequence{}, - rC, make_int_sequence{}, - &(traits.accumulate_), seq<0>{}); - } - else { - detail::explode(MMA_Op::fma, - rD, make_int_sequence{}, - rA, make_int_sequence{}, - rB, make_int_sequence{}, - rC, make_int_sequence{}); - } - } + CUTE_STATIC_ASSERT_V(size(rD) == Int{}); + CUTE_STATIC_ASSERT_V(size(rC) == Int{}); + + detail::explode(MMA_Op::fma, + rD, make_int_sequence{}, + rA, make_int_sequence{}, + rB, make_int_sequence{}, + rC, make_int_sequence{}); } -// // Accept mutable temporaries -// - 
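
A minimal sketch (not part of the patch) of the MMA_Op extraction trait introduced above, whose template parameter lists did not survive the plain-text rendering of this hunk. It assumes, as elsewhere in this file, that MMA_Traits takes the MMA operation as its first template argument:

    // Sketch only: recover the operation type from an MMA_Traits specialization.
    template <class MMATraits>
    struct MMA_Op {};

    template <class MMA_Op_Arg, class... Ts>
    struct MMA_Op<MMA_Traits<MMA_Op_Arg, Ts...>> {
      using type = MMA_Op_Arg;
    };

With this, the generic mma_unpack recasts D, A, B, and C to the operation's register types and forwards all four register arrays to detail::explode(MMA_Op::fma, ...); the accumulate-scaling and aliased-C/D handling that used to live here moves into the SM90 GMMA-specific mma_unpack added later in this patch.
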
-template CUTE_HOST_DEVICE constexpr void -mma_unpack(MMA_Traits const& traits, - Tensor && D, - Tensor const& A, - Tensor const& B, - Tensor const& C) +mma_unpack(AnyMMATraits const& traits, + Tensor && D, + Tensor const& A, + Tensor const& B, + Tensor const& C) { mma_unpack(traits, D, A, B, C); } diff --git a/include/cute/atom/mma_traits_sm90.hpp b/include/cute/atom/mma_traits_sm90.hpp index 437af27b21..b2ced3f878 100644 --- a/include/cute/atom/mma_traits_sm90.hpp +++ b/include/cute/atom/mma_traits_sm90.hpp @@ -41,6 +41,8 @@ namespace cute { //////////////////////// fp64 = fp64 * fp64 + fp64 //////////////////////////// /////////////////////////////////////////////////////////////////////////////// +using SM90_16x8x4_F64F64F64F64_TN = SM90::MMA_16x8x4_F64F64F64F64_TN; + template <> struct MMA_Traits { @@ -59,6 +61,8 @@ struct MMA_Traits Stride,Stride<_16,_8>>>; }; +using SM90_16x8x8_F64F64F64F64_TN = SM90::MMA_16x8x8_F64F64F64F64_TN; + template <> struct MMA_Traits { @@ -77,6 +81,8 @@ struct MMA_Traits Stride,Stride<_16,_8>>>; }; +using SM90_16x8x16_F64F64F64F64_TN = SM90::MMA_16x8x16_F64F64F64F64_TN; + template <> struct MMA_Traits { @@ -99,9 +105,11 @@ struct MMA_Traits //////////////////////// cfp64 = cfp64 * cfp64 + cfp64 //////////////////////////// /////////////////////////////////////////////////////////////////////////////////// +using SM90_16x8x4_C64C64C64C64_TN = SM90::MMA_16x8x4_C64C64C64C64_TN; + template <> struct MMA_Traits - : MMA_Traits + : MMA_Traits { using ValTypeD = complex; using ValTypeA = complex; @@ -109,9 +117,11 @@ struct MMA_Traits using ValTypeC = complex; }; +using SM90_16x8x8_C64C64C64C64_TN = SM90::MMA_16x8x8_C64C64C64C64_TN; + template <> struct MMA_Traits - : MMA_Traits + : MMA_Traits { using ValTypeD = complex; using ValTypeA = complex; @@ -119,9 +129,11 @@ struct MMA_Traits using ValTypeC = complex; }; +using SM90_16x8x16_C64C64C64C64_TN = SM90::MMA_16x8x16_C64C64C64C64_TN; + template <> struct MMA_Traits - : MMA_Traits + : MMA_Traits { using ValTypeD = complex; using ValTypeA = complex; diff --git a/include/cute/atom/mma_traits_sm90_gmma.hpp b/include/cute/atom/mma_traits_sm90_gmma.hpp index e59bbeefc2..74f3d64601 100644 --- a/include/cute/atom/mma_traits_sm90_gmma.hpp +++ b/include/cute/atom/mma_traits_sm90_gmma.hpp @@ -30,10 +30,15 @@ **************************************************************************************************/ #pragma once -#include -#include - -#include +#include // cute::smem_ptr_flag +#include // cute::smem_sparse_ptr_flag +#include // cute::Swizzle +#include // cute::Tensor +#include // cute::LayoutType +#include // cute::SM90_64x8x16_F16F16F16_SS, etc +#include // cute::MMA_Traits +#include // cute::ComposedLayout +#include // cute::is_static namespace cute { @@ -60,7 +65,7 @@ warpgroup_fence_operand(Tensor& frg) { } } -namespace GMMA { +namespace SM90::GMMA { /////////////////////////////////////////// // Common layouts for GMMA Shared Memory // @@ -99,20 +104,20 @@ template using Layout_K_SW128_Atom = decltype(upcast::value>(Layout_K_SW128_Atom_Bits{})); // With GMMA::Major param -template -using Layout_INTER_Atom = typename conditional +using Layout_INTER_Atom = typename conditional, Layout_K_INTER_Atom>::type; -template -using Layout_SW32_Atom = typename conditional +using Layout_SW32_Atom = typename conditional, Layout_K_SW32_Atom>::type; -template -using Layout_SW64_Atom = typename conditional +using Layout_SW64_Atom = typename conditional, Layout_K_SW64_Atom>::type; -template -using Layout_SW128_Atom = typename 
conditional +using Layout_SW128_Atom = typename conditional, Layout_K_SW128_Atom>::type; @@ -188,7 +193,7 @@ layout_type(Tensor> const&) * auto smem_layout = tile_to_shape(Layout_K_SW128_Atom{}, Shape<_128,_64>{}); * is guaranteed to be accepted by make_gmma_desc for appropriate value_type. */ -template +template CUTE_HOST_DEVICE constexpr GmmaDescriptor make_gmma_desc(Tensor const& tensor) @@ -203,7 +208,7 @@ make_gmma_desc(Tensor const& tensor) GmmaDescriptor desc; // Layout type - constexpr GMMA::LayoutType LAYOUT_TYPE = GMMA::layout_type(u128_tensor); + constexpr LayoutType LAYOUT_TYPE = layout_type(u128_tensor); desc.bitfield.layout_type_ = uint8_t(LAYOUT_TYPE); // Start address (4LSB not included) @@ -214,12 +219,12 @@ make_gmma_desc(Tensor const& tensor) desc.bitfield.base_offset_ = base_offset; // LayoutType meta - constexpr int W = LAYOUT_TYPE == GMMA::LayoutType::INTERLEAVE ? 1 : - LAYOUT_TYPE == GMMA::LayoutType::B32 ? 2 : - LAYOUT_TYPE == GMMA::LayoutType::B64 ? 4 : - LAYOUT_TYPE == GMMA::LayoutType::B128 ? 8 : -1; + constexpr int W = LAYOUT_TYPE == LayoutType::INTERLEAVE ? 1 : + LAYOUT_TYPE == LayoutType::B32 ? 2 : + LAYOUT_TYPE == LayoutType::B64 ? 4 : + LAYOUT_TYPE == LayoutType::B128 ? 8 : -1; - if constexpr (MajorMode == GMMA::Major::MN) + if constexpr (MajorMode == Major::MN) { /* In units of uint128_t, each GmmaDescriptor Major-MN describes a canonical layout of the form * @@ -228,8 +233,10 @@ make_gmma_desc(Tensor const& tensor) * LayoutType::B64 : Swizzle<2,4,3> o smem_ptr o ((4,n),(8,k)):((1,LBO),(4,SBO)) * LayoutType::B128 : Swizzle<3,4,3> o smem_ptr o ((8,n),(8,k)):((1,LBO),(8,SBO)) */ - static_assert(size<1>(u128_tensor) == Int<(256 / cute::sizeof_bits::value)>{}, // K size - "Not a canonical GMMA_MN Layout: Expected K-size 256/sizeof_bits."); + static_assert(size<1>(u128_tensor) == Int<(256 / cute::sizeof_bits::value)>{} || // A and B in dense MMA + size<1>(u128_tensor) == Int<(128 / cute::sizeof_bits::value)>{} || // A in sparse MMA + size<1>(u128_tensor) == Int<(512 / cute::sizeof_bits::value)>{}, // B in sparse MMA + "Not a canonical GMMA_MN Layout: Expected K-size 256/sizeof_bits for dense or (128|512)/sizeof_bits for sparse."); // Construct the canonical GMMA T Layout with shape ((W,n),(8,2)) Layout canonical_layout = logical_divide(layout(u128_tensor), make_tile(Layout,_1>{}, Layout,_1>{})); @@ -239,7 +246,7 @@ make_gmma_desc(Tensor const& tensor) CUTE_STATIC_ASSERT_V(rank<1>(canonical_layout) == Int<2>{}, "Not a canonical GMMA_MN Layout: No flat offset mode"); // Check canonical mode strides constexpr uint32_t stride_00 = stride<0,0>(canonical_layout); - constexpr uint32_t expected_stride_00 = LAYOUT_TYPE == GMMA::LayoutType::INTERLEAVE ? stride<0,0>(canonical_layout) : 1; + constexpr uint32_t expected_stride_00 = LAYOUT_TYPE == LayoutType::INTERLEAVE ? stride<0,0>(canonical_layout) : 1; static_assert(stride_00 == expected_stride_00, "Not a canonical GMMA_MN Layout: Expected stride failure."); constexpr uint32_t stride_10 = stride<1,0>(canonical_layout); constexpr uint32_t expected_stride_10 = W; @@ -249,10 +256,10 @@ make_gmma_desc(Tensor const& tensor) constexpr uint32_t stride_01 = stride<0,1>(canonical_layout); constexpr uint32_t stride_11 = stride<1,1>(canonical_layout); - desc.bitfield.stride_byte_offset_ = (LAYOUT_TYPE == GMMA::LayoutType::INTERLEAVE) ? stride_01 : stride_11; - desc.bitfield.leading_byte_offset_ = (LAYOUT_TYPE == GMMA::LayoutType::INTERLEAVE) ? 
stride_11 : stride_01; + desc.bitfield.stride_byte_offset_ = (LAYOUT_TYPE == LayoutType::INTERLEAVE) ? stride_01 : stride_11; + desc.bitfield.leading_byte_offset_ = (LAYOUT_TYPE == LayoutType::INTERLEAVE) ? stride_11 : stride_01; } - else if constexpr (MajorMode == GMMA::Major::K) + else if constexpr (MajorMode == Major::K) { /* In units of uint128_t, each GmmaDescriptor Major-K describes a canonical layout of the form * @@ -263,8 +270,8 @@ make_gmma_desc(Tensor const& tensor) */ CUTE_STATIC_ASSERT_V(size<0>(u128_tensor) % Int<8>{} == Int<0>{}, // N|M size "Not a canonical GMMA_K Layout: Expected MN-size multiple of 8."); - CUTE_STATIC_ASSERT_V(size<1>(u128_tensor) == Int<2>{}, // K size - "Not a canonical GMMA_K Layout: Expected K-size 2 (in units of uint128_t)."); + CUTE_STATIC_ASSERT_V(size<1>(u128_tensor) == Int<2>{} || size<1>(u128_tensor) == Int<4>{}, // K size + "Not a canonical GMMA_K Layout: Expected K-size 2 for dense or 4 for sparse (in units of uint128_t)."); // Construct the canonical GMMA N Layout with shape ((8,n),(2,1)) Layout canonical_layout = logical_divide(layout(u128_tensor), make_tile(Layout<_8,_1>{}, Layout<_2,_1>{})); @@ -277,7 +284,7 @@ make_gmma_desc(Tensor const& tensor) constexpr uint32_t expected_stride_00 = W; static_assert(stride_00 == expected_stride_00, "Not a canonical GMMA_K Layout: Expected stride failure."); constexpr uint32_t stride_10 = stride<1,0>(canonical_layout); - constexpr uint32_t expected_stride_10 = (LAYOUT_TYPE == GMMA::LayoutType::INTERLEAVE) ? stride<1,0>(canonical_layout) : 1; + constexpr uint32_t expected_stride_10 = (LAYOUT_TYPE == LayoutType::INTERLEAVE) ? stride<1,0>(canonical_layout) : 1; static_assert(stride_10 == expected_stride_10, "Not a canonical GMMA_K Layout: Expected stride failure."); // stride dimension byte offset and leading dimension byte offset (4LSB not included == uint128_t units) @@ -286,7 +293,7 @@ make_gmma_desc(Tensor const& tensor) desc.bitfield.stride_byte_offset_ = stride_01; desc.bitfield.leading_byte_offset_ = stride_10; } else { - static_assert(MajorMode != GMMA::Major::MN && MajorMode != GMMA::Major::K, "Unrecognized MajorMode!"); + static_assert(MajorMode != Major::MN && MajorMode != Major::K, "Unrecognized MajorMode!"); } #if 0 @@ -357,21 +364,21 @@ print(DescriptorIterator) { // The GMMA Traits below have custom fragment type flags for their smem desc tensors. // These flags specialize a MakeTensor customization point to correctly make the fragment that is desired. 
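
A brief usage sketch (not from the patch) mirroring the example in the make_gmma_desc doc comment earlier in this hunk: it builds the canonical K-major SW128 shared-memory layout that the comment guarantees make_gmma_desc will accept. The element type half_t and the helper name are illustrative additions:

    // Sketch only: a GMMA-compatible K-major, 128B-swizzled smem layout.
    #include <cute/tensor.hpp>
    #include <cute/atom/mma_traits_sm90_gmma.hpp>

    inline auto example_gmma_smem_layout()
    {
      using T = cute::half_t;
      return cute::tile_to_shape(cute::SM90::GMMA::Layout_K_SW128_Atom<T>{},
                                 cute::Shape<cute::_128, cute::_64>{});
      // Given a shared-memory tensor sA over this layout, a K-major GmmaDescriptor
      // would then follow from the make_gmma_desc overload defined above.
    }
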
-template +template struct smem_desc : DescriptorIterator {}; -} // end namespace GMMA +} // end namespace SM90::GMMA // Customization point for creating a GMMA::smem_desc Tensor -template -struct MakeTensor> +template +struct MakeTensor> { template CUTE_HOST_DEVICE constexpr auto operator()(Tensor const& smem_tensor) { static_assert(is_smem::value, "Expected SMEM Tensor to construct a GMMA Desc Tensor"); - return make_tensor(GMMA::DescriptorIterator{GMMA::make_gmma_desc(tensor<0>(smem_tensor))}, + return make_tensor(SM90::GMMA::DescriptorIterator{SM90::GMMA::make_gmma_desc(tensor<0>(smem_tensor))}, replace<0>(recast(smem_tensor).layout(), Layout<_1,_0>{})); } }; @@ -380,7 +387,58 @@ struct MakeTensor> //////////////////////////// MMA_TRAITS /////////////////////////////////////// /////////////////////////////////////////////////////////////////////////////// -namespace GMMA { +namespace SM90::GMMA { + +// +// Specialized mma_unpack implementation for SM90 GMMA instructions +// + +template +CUTE_HOST_DEVICE constexpr +void +mma_unpack(MMA_Traits const& traits, + Tensor & D, + Tensor const& A, + Tensor const& B, + Tensor const& C) +{ + static_assert(is_rmem::value, "Expected registers in MMA_Atom::call"); + static_assert(is_rmem::value, "Expected registers in MMA_Atom::call"); + static_assert(is_rmem::value, "Expected registers in MMA_Atom::call"); + static_assert(is_rmem::value, "Expected registers in MMA_Atom::call"); + + // Register value types from the MMA_Operation register arrays + using RegTypeA = typename remove_extent::type; + using RegTypeB = typename remove_extent::type; + using RegTypeC = typename remove_extent::type; + + // SM90 GMMA take three arguments rather than four, try to assert C and D are aliased + static_assert(is_same::value, "GMMA C and D value_type must match."); + static_assert(is_same::value, "GMMA C and D layouts must match."); + // assert((void*)&C == (void*)&D); + + Tensor rA = recast(A); + Tensor rB = recast(B); + Tensor rC = recast(D); // NOTE: D and C are same, so use mutable D + + constexpr int RegNumA = extent::value; + constexpr int RegNumB = extent::value; + constexpr int RegNumC = extent::value; + + CUTE_STATIC_ASSERT_V(size(rA) == Int{}); + CUTE_STATIC_ASSERT_V(size(rB) == Int{}); + CUTE_STATIC_ASSERT_V(size(rC) == Int{}); + + detail::explode(MMA_Op::fma, + rA, make_int_sequence{}, + rB, make_int_sequence{}, + rC, make_int_sequence{}, + &(traits.accumulate_), seq<0>{}); +} // Accumulator layouts using CLayout_64x8 = Layout,Shape < _2,_2>>, @@ -392,7 +450,7 @@ using CLayout_64x16 = Layout,Shape < _2,_2, _2>>, using CLayout_64x32 = Layout,Shape < _2,_2, _4>>, Stride,Stride<_64,_8,_512>>>; -using CLayout_64x48 = Layout,Shape < _2,_2, _6>>, +using CLayout_64x48 = Layout,Shape < _2,_2, _6>>, Stride,Stride<_64,_8,_512>>>; using CLayout_64x64 = Layout,Shape < _2,_2, _8>>, @@ -404,31 +462,31 @@ using CLayout_64x80 = Layout,Shape < _2,_2, _10>>, using CLayout_64x96 = Layout,Shape < _2,_2, _12>>, Stride,Stride<_64,_8,_512>>>; -using CLayout_64x112 = Layout,Shape < _2,_2, Int<14>>>, +using CLayout_64x112 = Layout,Shape < _2,_2, Int<14>>>, Stride,Stride<_64,_8,_512>>>; using CLayout_64x128 = Layout,Shape < _2,_2, _16>>, Stride,Stride<_64,_8,_512>>>; -using CLayout_64x144 = Layout,Shape < _2,_2, Int<18>>>, +using CLayout_64x144 = Layout,Shape < _2,_2, Int<18>>>, Stride,Stride<_64,_8,_512>>>; -using CLayout_64x160 = Layout,Shape < _2,_2, Int<20>>>, +using CLayout_64x160 = Layout,Shape < _2,_2, Int<20>>>, Stride,Stride<_64,_8,_512>>>; -using CLayout_64x176 = 
Layout,Shape < _2,_2, Int<22>>>, +using CLayout_64x176 = Layout,Shape < _2,_2, Int<22>>>, Stride,Stride<_64,_8,_512>>>; using CLayout_64x192 = Layout,Shape < _2,_2, _24>>, Stride,Stride<_64,_8,_512>>>; -using CLayout_64x208 = Layout,Shape < _2,_2, Int<26>>>, +using CLayout_64x208 = Layout,Shape < _2,_2, Int<26>>>, Stride,Stride<_64,_8,_512>>>; -using CLayout_64x224 = Layout,Shape < _2,_2, Int<28>>>, +using CLayout_64x224 = Layout,Shape < _2,_2, Int<28>>>, Stride,Stride<_64,_8,_512>>>; -using CLayout_64x240 = Layout,Shape < _2,_2, Int<30>>>, +using CLayout_64x240 = Layout,Shape < _2,_2, Int<30>>>, Stride,Stride<_64,_8,_512>>>; using CLayout_64x256 = Layout,Shape < _2,_2, _32>>, @@ -438,19 +496,33 @@ using CLayout_64x256 = Layout,Shape < _2,_2, _32>>, using ALayout_64x8 = Layout,Shape < _2, _2>>, Stride,Stride< _8,_256>>>; -// Register source layout for 16-bit value types -using ALayout_64x16 = CLayout_64x16; +// Register source layout for 16-bit (sparse 32-bit) value types +using ALayout_64x16 = CLayout_64x16; -// Register source layout for 8-bit value types -using ALayout_64x32 = Layout,Shape < _4,_2, _2>>, - Stride,Stride<_64,_8,_1024>>>; +// Register source layout for 8-bit (sparse 16-bit) value types +using ALayout_64x32 = Layout,Shape < _4,_2, _2>>, + Stride,Stride<_64,_8,_1024>>>; + +// Register source layout for sparse 8-bit value types +using ALayout_64x64 = Layout,Shape < _8,_2, _2>>, + Stride,Stride<_64,_8,_2048>>>; // Shared memory source layouts for any value type template using ABLayout = Layout,Int>>, Stride< _0,Stride< _1,Int>>>; -} // namespace GMMA +} // end namespace SM90::GMMA + +using namespace SM90; + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x8x16_F16F16F16_SS = SM90::GMMA::MMA_64x8x16_F16F16F16_SS; template struct MMA_Traits> @@ -474,6 +546,15 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x8x16_F16F16F16_RS = SM90::GMMA::MMA_64x8x16_F16F16F16_RS; + template struct MMA_Traits> { @@ -495,6 +576,15 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x16x16_F16F16F16_SS = SM90::GMMA::MMA_64x16x16_F16F16F16_SS; + template struct MMA_Traits> { @@ -517,6 +607,15 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x16x16_F16F16F16_RS = SM90::GMMA::MMA_64x16x16_F16F16F16_RS; + template struct MMA_Traits> { @@ -538,6 +637,15 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x32x16_F16F16F16_SS = SM90::GMMA::MMA_64x32x16_F16F16F16_SS; + template struct MMA_Traits> { @@ -560,6 +668,15 @@ struct MMA_Traits> 
//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x32x16_F16F16F16_RS = SM90::GMMA::MMA_64x32x16_F16F16F16_RS; + template struct MMA_Traits> { @@ -582,6 +699,15 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x48x16_F16F16F16_SS = SM90::GMMA::MMA_64x48x16_F16F16F16_SS; + template struct MMA_Traits> { @@ -606,6 +732,15 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x48x16_F16F16F16_RS = SM90::GMMA::MMA_64x48x16_F16F16F16_RS; + template struct MMA_Traits> { @@ -628,6 +763,15 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x64x16_F16F16F16_SS = SM90::GMMA::MMA_64x64x16_F16F16F16_SS; + template struct MMA_Traits> { @@ -650,6 +794,15 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x64x16_F16F16F16_RS = SM90::GMMA::MMA_64x64x16_F16F16F16_RS; + template struct MMA_Traits> { @@ -672,6 +825,15 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x80x16_F16F16F16_SS = SM90::GMMA::MMA_64x80x16_F16F16F16_SS; + template struct MMA_Traits> { @@ -696,6 +858,15 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x80x16_F16F16F16_RS = SM90::GMMA::MMA_64x80x16_F16F16F16_RS; + template struct MMA_Traits> { @@ -718,6 +889,15 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x96x16_F16F16F16_SS = SM90::GMMA::MMA_64x96x16_F16F16F16_SS; + template struct MMA_Traits> { @@ -740,6 +920,15 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = 
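
The remainder of this hunk repeats the same pattern for every GMMA tile shape and element-type combination: the instruction struct now lives in SM90::GMMA under the MMA_64xNx16_* name, and a cute-scope alias preserves the original SM90_64xNx16_* spelling. A short sketch (not from the patch; the 64x64 FP16 op and K-major layouts are illustrative) of why the aliases matter, namely that existing call sites keep compiling unchanged:

    // Sketch only: the pre-existing spelling still names the relocated op.
    #include <cute/atom/mma_atom.hpp>
    #include <cute/atom/mma_traits_sm90_gmma.hpp>

    inline auto example_tiled_mma()
    {
      using namespace cute;
      // SM90_64x64x16_F16F16F16_SS is now an alias for SM90::GMMA::MMA_64x64x16_F16F16F16_SS.
      return make_tiled_mma(SM90_64x64x16_F16F16F16_SS<GMMA::Major::K, GMMA::Major::K>{});
    }
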
GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x96x16_F16F16F16_RS = SM90::GMMA::MMA_64x96x16_F16F16F16_RS; + template struct MMA_Traits> { @@ -762,6 +951,15 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x112x16_F16F16F16_SS = SM90::GMMA::MMA_64x112x16_F16F16F16_SS; + template struct MMA_Traits> { @@ -786,6 +984,15 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x112x16_F16F16F16_RS = SM90::GMMA::MMA_64x112x16_F16F16F16_RS; + template struct MMA_Traits> { @@ -808,6 +1015,15 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x128x16_F16F16F16_SS = SM90::GMMA::MMA_64x128x16_F16F16F16_SS; + template struct MMA_Traits> { @@ -830,6 +1046,15 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x128x16_F16F16F16_RS = SM90::GMMA::MMA_64x128x16_F16F16F16_RS; + template struct MMA_Traits> { @@ -852,6 +1077,15 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x144x16_F16F16F16_SS = SM90::GMMA::MMA_64x144x16_F16F16F16_SS; + template struct MMA_Traits> { @@ -876,6 +1110,15 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x144x16_F16F16F16_RS = SM90::GMMA::MMA_64x144x16_F16F16F16_RS; + template struct MMA_Traits> { @@ -899,6 +1142,15 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x160x16_F16F16F16_SS = SM90::GMMA::MMA_64x160x16_F16F16F16_SS; + template struct MMA_Traits> { @@ -923,6 +1175,15 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = 
GMMA::ScaleIn::One +> +using SM90_64x160x16_F16F16F16_RS = SM90::GMMA::MMA_64x160x16_F16F16F16_RS; + template struct MMA_Traits> { @@ -946,6 +1207,15 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x176x16_F16F16F16_SS = SM90::GMMA::MMA_64x176x16_F16F16F16_SS; + template struct MMA_Traits> { @@ -970,6 +1240,15 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x176x16_F16F16F16_RS = SM90::GMMA::MMA_64x176x16_F16F16F16_RS; + template struct MMA_Traits> { @@ -992,6 +1271,15 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x192x16_F16F16F16_SS = SM90::GMMA::MMA_64x192x16_F16F16F16_SS; + template struct MMA_Traits> { @@ -1014,6 +1302,15 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x192x16_F16F16F16_RS = SM90::GMMA::MMA_64x192x16_F16F16F16_RS; + template struct MMA_Traits> { @@ -1036,6 +1333,15 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x208x16_F16F16F16_SS = SM90::GMMA::MMA_64x208x16_F16F16F16_SS; + template struct MMA_Traits> { @@ -1060,6 +1366,15 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x208x16_F16F16F16_RS = SM90::GMMA::MMA_64x208x16_F16F16F16_RS; + template struct MMA_Traits> { @@ -1083,6 +1398,15 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x224x16_F16F16F16_SS = SM90::GMMA::MMA_64x224x16_F16F16F16_SS; + template struct MMA_Traits> { @@ -1107,6 +1431,15 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using 
SM90_64x224x16_F16F16F16_RS = SM90::GMMA::MMA_64x224x16_F16F16F16_RS; + template struct MMA_Traits> { @@ -1130,6 +1463,15 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x240x16_F16F16F16_SS = SM90::GMMA::MMA_64x240x16_F16F16F16_SS; + template struct MMA_Traits> { @@ -1154,6 +1496,15 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x240x16_F16F16F16_RS = SM90::GMMA::MMA_64x240x16_F16F16F16_RS; + template struct MMA_Traits> { @@ -1176,6 +1527,15 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x256x16_F16F16F16_SS = SM90::GMMA::MMA_64x256x16_F16F16F16_SS; + template struct MMA_Traits> { @@ -1198,6 +1558,15 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x256x16_F16F16F16_RS = SM90::GMMA::MMA_64x256x16_F16F16F16_RS; + template struct MMA_Traits> { @@ -1219,6 +1588,15 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x8x16_F32F16F16_SS = SM90::GMMA::MMA_64x8x16_F32F16F16_SS; + template struct MMA_Traits> { @@ -1241,6 +1619,15 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x8x16_F32F16F16_RS = SM90::GMMA::MMA_64x8x16_F32F16F16_RS; + template struct MMA_Traits> { @@ -1262,6 +1649,15 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x16x16_F32F16F16_SS = SM90::GMMA::MMA_64x16x16_F32F16F16_SS; + template struct MMA_Traits> { @@ -1284,6 +1680,15 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x16x16_F32F16F16_RS = SM90::GMMA::MMA_64x16x16_F32F16F16_RS; + template struct MMA_Traits> { @@ -1305,6 +1710,15 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + 
GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x32x16_F32F16F16_SS = SM90::GMMA::MMA_64x32x16_F32F16F16_SS; + template struct MMA_Traits> { @@ -1327,6 +1741,15 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x32x16_F32F16F16_RS = SM90::GMMA::MMA_64x32x16_F32F16F16_RS; + template struct MMA_Traits> { @@ -1349,6 +1772,15 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x48x16_F32F16F16_SS = SM90::GMMA::MMA_64x48x16_F32F16F16_SS; + template struct MMA_Traits> { @@ -1373,6 +1805,15 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x48x16_F32F16F16_RS = SM90::GMMA::MMA_64x48x16_F32F16F16_RS; + template struct MMA_Traits> { @@ -1395,6 +1836,15 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x64x16_F32F16F16_SS = SM90::GMMA::MMA_64x64x16_F32F16F16_SS; + template struct MMA_Traits> { @@ -1417,6 +1867,15 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x64x16_F32F16F16_RS = SM90::GMMA::MMA_64x64x16_F32F16F16_RS; + template struct MMA_Traits> { @@ -1439,6 +1898,15 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x80x16_F32F16F16_SS = SM90::GMMA::MMA_64x80x16_F32F16F16_SS; + template struct MMA_Traits> { @@ -1463,6 +1931,15 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x80x16_F32F16F16_RS = SM90::GMMA::MMA_64x80x16_F32F16F16_RS; + template struct MMA_Traits> { @@ -1485,6 +1962,15 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x96x16_F32F16F16_SS 
= SM90::GMMA::MMA_64x96x16_F32F16F16_SS; + template struct MMA_Traits> { @@ -1507,6 +1993,15 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x96x16_F32F16F16_RS = SM90::GMMA::MMA_64x96x16_F32F16F16_RS; + template struct MMA_Traits> { @@ -1529,6 +2024,15 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x112x16_F32F16F16_SS = SM90::GMMA::MMA_64x112x16_F32F16F16_SS; + template struct MMA_Traits> { @@ -1553,6 +2057,15 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x112x16_F32F16F16_RS = SM90::GMMA::MMA_64x112x16_F32F16F16_RS; + template struct MMA_Traits> { @@ -1575,6 +2088,15 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x128x16_F32F16F16_SS = SM90::GMMA::MMA_64x128x16_F32F16F16_SS; + template struct MMA_Traits> { @@ -1597,6 +2119,15 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x128x16_F32F16F16_RS = SM90::GMMA::MMA_64x128x16_F32F16F16_RS; + template struct MMA_Traits> { @@ -1619,6 +2150,15 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x144x16_F32F16F16_SS = SM90::GMMA::MMA_64x144x16_F32F16F16_SS; + template struct MMA_Traits> { @@ -1643,6 +2183,15 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x144x16_F32F16F16_RS = SM90::GMMA::MMA_64x144x16_F32F16F16_RS; + template struct MMA_Traits> { @@ -1666,6 +2215,15 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x160x16_F32F16F16_SS = SM90::GMMA::MMA_64x160x16_F32F16F16_SS; + template struct MMA_Traits> { @@ -1690,6 +2248,15 @@ 
struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x160x16_F32F16F16_RS = SM90::GMMA::MMA_64x160x16_F32F16F16_RS; + template struct MMA_Traits> { @@ -1713,6 +2280,15 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x176x16_F32F16F16_SS = SM90::GMMA::MMA_64x176x16_F32F16F16_SS; + template struct MMA_Traits> { @@ -1737,6 +2313,15 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x176x16_F32F16F16_RS = SM90::GMMA::MMA_64x176x16_F32F16F16_RS; + template struct MMA_Traits> { @@ -1759,6 +2344,15 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x192x16_F32F16F16_SS = SM90::GMMA::MMA_64x192x16_F32F16F16_SS; + template struct MMA_Traits> { @@ -1781,6 +2375,15 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x192x16_F32F16F16_RS = SM90::GMMA::MMA_64x192x16_F32F16F16_RS; + template struct MMA_Traits> { @@ -1803,6 +2406,15 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x208x16_F32F16F16_SS = SM90::GMMA::MMA_64x208x16_F32F16F16_SS; + template struct MMA_Traits> { @@ -1827,6 +2439,15 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x208x16_F32F16F16_RS = SM90::GMMA::MMA_64x208x16_F32F16F16_RS; + template struct MMA_Traits> { @@ -1850,6 +2471,15 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x224x16_F32F16F16_SS = SM90::GMMA::MMA_64x224x16_F32F16F16_SS; + template struct MMA_Traits> { @@ -1874,6 +2504,15 @@ struct MMA_Traits> 
//////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x224x16_F32F16F16_RS = SM90::GMMA::MMA_64x224x16_F32F16F16_RS; + template struct MMA_Traits> { @@ -1897,6 +2536,15 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x240x16_F32F16F16_SS = SM90::GMMA::MMA_64x240x16_F32F16F16_SS; + template struct MMA_Traits> { @@ -1921,6 +2569,15 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x240x16_F32F16F16_RS = SM90::GMMA::MMA_64x240x16_F32F16F16_RS; + template struct MMA_Traits> { @@ -1943,6 +2600,15 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x256x16_F32F16F16_SS = SM90::GMMA::MMA_64x256x16_F32F16F16_SS; + template struct MMA_Traits> { @@ -1965,6 +2631,15 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x256x16_F32F16F16_RS = SM90::GMMA::MMA_64x256x16_F32F16F16_RS; + template struct MMA_Traits> { @@ -1986,6 +2661,15 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x8x16_F32BF16BF16_SS = SM90::GMMA::MMA_64x8x16_F32BF16BF16_SS; + template struct MMA_Traits> { @@ -2008,6 +2692,15 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x8x16_F32BF16BF16_RS = SM90::GMMA::MMA_64x8x16_F32BF16BF16_RS; + template struct MMA_Traits> { @@ -2029,6 +2722,15 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x16x16_F32BF16BF16_SS = SM90::GMMA::MMA_64x16x16_F32BF16BF16_SS; + template struct MMA_Traits> { @@ -2051,6 +2753,15 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = 
GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x16x16_F32BF16BF16_RS = SM90::GMMA::MMA_64x16x16_F32BF16BF16_RS; + template struct MMA_Traits> { @@ -2072,6 +2783,15 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x32x16_F32BF16BF16_SS = SM90::GMMA::MMA_64x32x16_F32BF16BF16_SS; + template struct MMA_Traits> { @@ -2094,6 +2814,15 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x32x16_F32BF16BF16_RS = SM90::GMMA::MMA_64x32x16_F32BF16BF16_RS; + template struct MMA_Traits> { @@ -2116,6 +2845,15 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x48x16_F32BF16BF16_SS = SM90::GMMA::MMA_64x48x16_F32BF16BF16_SS; + template struct MMA_Traits> { @@ -2140,6 +2878,15 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x48x16_F32BF16BF16_RS = SM90::GMMA::MMA_64x48x16_F32BF16BF16_RS; + template struct MMA_Traits> { @@ -2162,6 +2909,15 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x64x16_F32BF16BF16_SS = SM90::GMMA::MMA_64x64x16_F32BF16BF16_SS; + template struct MMA_Traits> { @@ -2184,6 +2940,15 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x64x16_F32BF16BF16_RS = SM90::GMMA::MMA_64x64x16_F32BF16BF16_RS; + template struct MMA_Traits> { @@ -2206,6 +2971,15 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x80x16_F32BF16BF16_SS = SM90::GMMA::MMA_64x80x16_F32BF16BF16_SS; + template struct MMA_Traits> { @@ -2230,6 +3004,15 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x80x16_F32BF16BF16_RS = 
SM90::GMMA::MMA_64x80x16_F32BF16BF16_RS; + template struct MMA_Traits> { @@ -2252,6 +3035,15 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x96x16_F32BF16BF16_SS = SM90::GMMA::MMA_64x96x16_F32BF16BF16_SS; + template struct MMA_Traits> { @@ -2274,6 +3066,15 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x96x16_F32BF16BF16_RS = SM90::GMMA::MMA_64x96x16_F32BF16BF16_RS; + template struct MMA_Traits> { @@ -2296,6 +3097,15 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x112x16_F32BF16BF16_SS = SM90::GMMA::MMA_64x112x16_F32BF16BF16_SS; + template struct MMA_Traits> { @@ -2320,6 +3130,15 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x112x16_F32BF16BF16_RS = SM90::GMMA::MMA_64x112x16_F32BF16BF16_RS; + template struct MMA_Traits> { @@ -2342,6 +3161,15 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x128x16_F32BF16BF16_SS = SM90::GMMA::MMA_64x128x16_F32BF16BF16_SS; + template struct MMA_Traits> { @@ -2364,6 +3192,15 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x128x16_F32BF16BF16_RS = SM90::GMMA::MMA_64x128x16_F32BF16BF16_RS; + template struct MMA_Traits> { @@ -2386,6 +3223,15 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x144x16_F32BF16BF16_SS = SM90::GMMA::MMA_64x144x16_F32BF16BF16_SS; + template struct MMA_Traits> { @@ -2410,6 +3256,15 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x144x16_F32BF16BF16_RS = SM90::GMMA::MMA_64x144x16_F32BF16BF16_RS; + template struct MMA_Traits> { @@ -2433,6 +3288,15 @@ struct MMA_Traits> 
//////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x160x16_F32BF16BF16_SS = SM90::GMMA::MMA_64x160x16_F32BF16BF16_SS; + template struct MMA_Traits> { @@ -2457,6 +3321,15 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x160x16_F32BF16BF16_RS = SM90::GMMA::MMA_64x160x16_F32BF16BF16_RS; + template struct MMA_Traits> { @@ -2480,6 +3353,15 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x176x16_F32BF16BF16_SS = SM90::GMMA::MMA_64x176x16_F32BF16BF16_SS; + template struct MMA_Traits> { @@ -2504,6 +3386,15 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x176x16_F32BF16BF16_RS = SM90::GMMA::MMA_64x176x16_F32BF16BF16_RS; + template struct MMA_Traits> { @@ -2526,6 +3417,15 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x192x16_F32BF16BF16_SS = SM90::GMMA::MMA_64x192x16_F32BF16BF16_SS; + template struct MMA_Traits> { @@ -2548,6 +3448,15 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x192x16_F32BF16BF16_RS = SM90::GMMA::MMA_64x192x16_F32BF16BF16_RS; + template struct MMA_Traits> { @@ -2570,6 +3479,15 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x208x16_F32BF16BF16_SS = SM90::GMMA::MMA_64x208x16_F32BF16BF16_SS; + template struct MMA_Traits> { @@ -2594,6 +3512,15 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x208x16_F32BF16BF16_RS = SM90::GMMA::MMA_64x208x16_F32BF16BF16_RS; + template struct MMA_Traits> { @@ -2617,6 +3544,15 @@ struct MMA_Traits> 
//////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x224x16_F32BF16BF16_SS = SM90::GMMA::MMA_64x224x16_F32BF16BF16_SS; + template struct MMA_Traits> { @@ -2641,6 +3577,15 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x224x16_F32BF16BF16_RS = SM90::GMMA::MMA_64x224x16_F32BF16BF16_RS; + template struct MMA_Traits> { @@ -2664,6 +3609,15 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x240x16_F32BF16BF16_SS = SM90::GMMA::MMA_64x240x16_F32BF16BF16_SS; + template struct MMA_Traits> { @@ -2688,6 +3642,15 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x240x16_F32BF16BF16_RS = SM90::GMMA::MMA_64x240x16_F32BF16BF16_RS; + template struct MMA_Traits> { @@ -2710,6 +3673,15 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x256x16_F32BF16BF16_SS = SM90::GMMA::MMA_64x256x16_F32BF16BF16_SS; + template struct MMA_Traits> { @@ -2732,6 +3704,15 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x256x16_F32BF16BF16_RS = SM90::GMMA::MMA_64x256x16_F32BF16BF16_RS; + template struct MMA_Traits> { @@ -2753,6 +3734,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x8x8_F32TF32TF32_SS_TN = SM90::GMMA::MMA_64x8x8_F32TF32TF32_SS_TN; + template struct MMA_Traits> { @@ -2775,6 +3763,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x8x8_F32TF32TF32_RS_TN = SM90::GMMA::MMA_64x8x8_F32TF32TF32_RS_TN; + template struct MMA_Traits> { @@ -2796,6 +3791,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = 
GMMA::ScaleIn::One +> +using SM90_64x16x8_F32TF32TF32_SS_TN = SM90::GMMA::MMA_64x16x8_F32TF32TF32_SS_TN; + template struct MMA_Traits> { @@ -2818,6 +3820,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x16x8_F32TF32TF32_RS_TN = SM90::GMMA::MMA_64x16x8_F32TF32TF32_RS_TN; + template struct MMA_Traits> { @@ -2839,6 +3848,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x32x8_F32TF32TF32_SS_TN = SM90::GMMA::MMA_64x32x8_F32TF32TF32_SS_TN; + template struct MMA_Traits> { @@ -2861,6 +3877,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x32x8_F32TF32TF32_RS_TN = SM90::GMMA::MMA_64x32x8_F32TF32TF32_RS_TN; + template struct MMA_Traits> { @@ -2883,6 +3906,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x48x8_F32TF32TF32_SS_TN = SM90::GMMA::MMA_64x48x8_F32TF32TF32_SS_TN; + template struct MMA_Traits> { @@ -2907,6 +3937,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x48x8_F32TF32TF32_RS_TN = SM90::GMMA::MMA_64x48x8_F32TF32TF32_RS_TN; + template struct MMA_Traits> { @@ -2929,6 +3966,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x64x8_F32TF32TF32_SS_TN = SM90::GMMA::MMA_64x64x8_F32TF32TF32_SS_TN; + template struct MMA_Traits> { @@ -2951,6 +3995,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x64x8_F32TF32TF32_RS_TN = SM90::GMMA::MMA_64x64x8_F32TF32TF32_RS_TN; + template struct MMA_Traits> { @@ -2973,6 +4024,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x80x8_F32TF32TF32_SS_TN = SM90::GMMA::MMA_64x80x8_F32TF32TF32_SS_TN; + template struct MMA_Traits> { @@ -2997,6 +4055,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x80x8_F32TF32TF32_RS_TN = 
SM90::GMMA::MMA_64x80x8_F32TF32TF32_RS_TN; + template struct MMA_Traits> { @@ -3019,6 +4084,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x96x8_F32TF32TF32_SS_TN = SM90::GMMA::MMA_64x96x8_F32TF32TF32_SS_TN; + template struct MMA_Traits> { @@ -3041,6 +4113,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x96x8_F32TF32TF32_RS_TN = SM90::GMMA::MMA_64x96x8_F32TF32TF32_RS_TN; + template struct MMA_Traits> { @@ -3063,6 +4142,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x112x8_F32TF32TF32_SS_TN = SM90::GMMA::MMA_64x112x8_F32TF32TF32_SS_TN; + template struct MMA_Traits> { @@ -3087,6 +4173,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x112x8_F32TF32TF32_RS_TN = SM90::GMMA::MMA_64x112x8_F32TF32TF32_RS_TN; + template struct MMA_Traits> { @@ -3109,6 +4202,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x128x8_F32TF32TF32_SS_TN = SM90::GMMA::MMA_64x128x8_F32TF32TF32_SS_TN; + template struct MMA_Traits> { @@ -3131,6 +4231,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x128x8_F32TF32TF32_RS_TN = SM90::GMMA::MMA_64x128x8_F32TF32TF32_RS_TN; + template struct MMA_Traits> { @@ -3153,6 +4260,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x144x8_F32TF32TF32_SS_TN = SM90::GMMA::MMA_64x144x8_F32TF32TF32_SS_TN; + template struct MMA_Traits> { @@ -3177,6 +4291,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x144x8_F32TF32TF32_RS_TN = SM90::GMMA::MMA_64x144x8_F32TF32TF32_RS_TN; + template struct MMA_Traits> { @@ -3200,6 +4321,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x160x8_F32TF32TF32_SS_TN = 
SM90::GMMA::MMA_64x160x8_F32TF32TF32_SS_TN; + template struct MMA_Traits> { @@ -3224,7 +4352,14 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -template + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x160x8_F32TF32TF32_RS_TN = SM90::GMMA::MMA_64x160x8_F32TF32TF32_RS_TN; + +template struct MMA_Traits> { using ValTypeD = float; @@ -3247,6 +4382,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x176x8_F32TF32TF32_SS_TN = SM90::GMMA::MMA_64x176x8_F32TF32TF32_SS_TN; + template struct MMA_Traits> { @@ -3271,6 +4413,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x176x8_F32TF32TF32_RS_TN = SM90::GMMA::MMA_64x176x8_F32TF32TF32_RS_TN; + template struct MMA_Traits> { @@ -3293,6 +4442,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x192x8_F32TF32TF32_SS_TN = SM90::GMMA::MMA_64x192x8_F32TF32TF32_SS_TN; + template struct MMA_Traits> { @@ -3315,6 +4471,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x192x8_F32TF32TF32_RS_TN = SM90::GMMA::MMA_64x192x8_F32TF32TF32_RS_TN; + template struct MMA_Traits> { @@ -3337,6 +4500,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x208x8_F32TF32TF32_SS_TN = SM90::GMMA::MMA_64x208x8_F32TF32TF32_SS_TN; + template struct MMA_Traits> { @@ -3361,6 +4531,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x208x8_F32TF32TF32_RS_TN = SM90::GMMA::MMA_64x208x8_F32TF32TF32_RS_TN; + template struct MMA_Traits> { @@ -3384,6 +4561,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x224x8_F32TF32TF32_SS_TN = SM90::GMMA::MMA_64x224x8_F32TF32TF32_SS_TN; + template struct MMA_Traits> { @@ -3408,6 +4592,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + +template < + 
GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x224x8_F32TF32TF32_RS_TN = SM90::GMMA::MMA_64x224x8_F32TF32TF32_RS_TN; + template struct MMA_Traits> { @@ -3431,6 +4622,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x240x8_F32TF32TF32_SS_TN = SM90::GMMA::MMA_64x240x8_F32TF32TF32_SS_TN; + template struct MMA_Traits> { @@ -3455,6 +4653,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x240x8_F32TF32TF32_RS_TN = SM90::GMMA::MMA_64x240x8_F32TF32TF32_RS_TN; + template struct MMA_Traits> { @@ -3477,6 +4682,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x256x8_F32TF32TF32_SS_TN = SM90::GMMA::MMA_64x256x8_F32TF32TF32_SS_TN; + template struct MMA_Traits> { @@ -3499,6 +4711,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x256x8_F32TF32TF32_RS_TN = SM90::GMMA::MMA_64x256x8_F32TF32TF32_RS_TN; + template struct MMA_Traits> { @@ -3520,6 +4739,10 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x8x32_S32S8S8_SS_TN = SM90::GMMA::MMA_64x8x32_S32S8S8_SS_TN; + template <> struct MMA_Traits { @@ -3542,6 +4765,10 @@ struct MMA_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x8x32_S32S8S8_SS_TN_SATURATE = SM90::GMMA::MMA_64x8x32_S32S8S8_SS_TN_SATURATE; + template <> struct MMA_Traits { @@ -3564,6 +4791,10 @@ struct MMA_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x16x32_S32S8S8_SS_TN = SM90::GMMA::MMA_64x16x32_S32S8S8_SS_TN; + template <> struct MMA_Traits { @@ -3586,6 +4817,10 @@ struct MMA_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x16x32_S32S8S8_SS_TN_SATURATE = SM90::GMMA::MMA_64x16x32_S32S8S8_SS_TN_SATURATE; + template <> struct MMA_Traits { @@ -3608,6 +4843,10 @@ struct MMA_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x32x32_S32S8S8_SS_TN = SM90::GMMA::MMA_64x32x32_S32S8S8_SS_TN; + template <> struct MMA_Traits { @@ -3630,6 +4869,10 @@ struct MMA_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x32x32_S32S8S8_SS_TN_SATURATE = SM90::GMMA::MMA_64x32x32_S32S8S8_SS_TN_SATURATE; + template <> struct MMA_Traits { @@ -3653,6 +4896,10 @@ struct MMA_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// #if 
defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + + +using SM90_64x48x32_S32S8S8_SS_TN = SM90::GMMA::MMA_64x48x32_S32S8S8_SS_TN; + template <> struct MMA_Traits { @@ -3677,6 +4924,10 @@ struct MMA_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + + +using SM90_64x48x32_S32S8S8_SS_TN_SATURATE = SM90::GMMA::MMA_64x48x32_S32S8S8_SS_TN_SATURATE; + template <> struct MMA_Traits { @@ -3700,6 +4951,10 @@ struct MMA_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x64x32_S32S8S8_SS_TN = SM90::GMMA::MMA_64x64x32_S32S8S8_SS_TN; + template <> struct MMA_Traits { @@ -3722,6 +4977,10 @@ struct MMA_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x64x32_S32S8S8_SS_TN_SATURATE = SM90::GMMA::MMA_64x64x32_S32S8S8_SS_TN_SATURATE; + template <> struct MMA_Traits { @@ -3745,6 +5004,10 @@ struct MMA_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + + +using SM90_64x80x32_S32S8S8_SS_TN = SM90::GMMA::MMA_64x80x32_S32S8S8_SS_TN; + template <> struct MMA_Traits { @@ -3769,6 +5032,10 @@ struct MMA_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + + +using SM90_64x80x32_S32S8S8_SS_TN_SATURATE = SM90::GMMA::MMA_64x80x32_S32S8S8_SS_TN_SATURATE; + template <> struct MMA_Traits { @@ -3792,6 +5059,10 @@ struct MMA_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x96x32_S32S8S8_SS_TN = SM90::GMMA::MMA_64x96x32_S32S8S8_SS_TN; + template <> struct MMA_Traits { @@ -3814,6 +5085,10 @@ struct MMA_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x96x32_S32S8S8_SS_TN_SATURATE = SM90::GMMA::MMA_64x96x32_S32S8S8_SS_TN_SATURATE; + template <> struct MMA_Traits { @@ -3837,6 +5112,10 @@ struct MMA_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + + +using SM90_64x112x32_S32S8S8_SS_TN = SM90::GMMA::MMA_64x112x32_S32S8S8_SS_TN; + template <> struct MMA_Traits { @@ -3861,6 +5140,10 @@ struct MMA_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + + +using SM90_64x112x32_S32S8S8_SS_TN_SATURATE = SM90::GMMA::MMA_64x112x32_S32S8S8_SS_TN_SATURATE; + template <> struct MMA_Traits { @@ -3884,6 +5167,10 @@ struct MMA_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x128x32_S32S8S8_SS_TN = SM90::GMMA::MMA_64x128x32_S32S8S8_SS_TN; + template <> struct MMA_Traits { @@ -3906,6 +5193,10 @@ struct MMA_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x128x32_S32S8S8_SS_TN_SATURATE = SM90::GMMA::MMA_64x128x32_S32S8S8_SS_TN_SATURATE; + template <> struct MMA_Traits { @@ -3929,6 +5220,10 @@ struct MMA_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + 
+ +using SM90_64x144x32_S32S8S8_SS_TN = SM90::GMMA::MMA_64x144x32_S32S8S8_SS_TN; + template <> struct MMA_Traits { @@ -3953,6 +5248,10 @@ struct MMA_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + + +using SM90_64x144x32_S32S8S8_SS_TN_SATURATE = SM90::GMMA::MMA_64x144x32_S32S8S8_SS_TN_SATURATE; + template <> struct MMA_Traits { @@ -3977,6 +5276,10 @@ struct MMA_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + + +using SM90_64x160x32_S32S8S8_SS_TN = SM90::GMMA::MMA_64x160x32_S32S8S8_SS_TN; + template <> struct MMA_Traits { @@ -4001,6 +5304,10 @@ struct MMA_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + + +using SM90_64x160x32_S32S8S8_SS_TN_SATURATE = SM90::GMMA::MMA_64x160x32_S32S8S8_SS_TN_SATURATE; + template <> struct MMA_Traits { @@ -4025,6 +5332,10 @@ struct MMA_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + + +using SM90_64x176x32_S32S8S8_SS_TN = SM90::GMMA::MMA_64x176x32_S32S8S8_SS_TN; + template <> struct MMA_Traits { @@ -4049,6 +5360,10 @@ struct MMA_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + + +using SM90_64x176x32_S32S8S8_SS_TN_SATURATE = SM90::GMMA::MMA_64x176x32_S32S8S8_SS_TN_SATURATE; + template <> struct MMA_Traits { @@ -4072,6 +5387,10 @@ struct MMA_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x192x32_S32S8S8_SS_TN = SM90::GMMA::MMA_64x192x32_S32S8S8_SS_TN; + template <> struct MMA_Traits { @@ -4094,6 +5413,10 @@ struct MMA_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x192x32_S32S8S8_SS_TN_SATURATE = SM90::GMMA::MMA_64x192x32_S32S8S8_SS_TN_SATURATE; + template <> struct MMA_Traits { @@ -4117,6 +5440,10 @@ struct MMA_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + + +using SM90_64x208x32_S32S8S8_SS_TN = SM90::GMMA::MMA_64x208x32_S32S8S8_SS_TN; + template <> struct MMA_Traits { @@ -4141,6 +5468,10 @@ struct MMA_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + + +using SM90_64x208x32_S32S8S8_SS_TN_SATURATE = SM90::GMMA::MMA_64x208x32_S32S8S8_SS_TN_SATURATE; + template <> struct MMA_Traits { @@ -4165,6 +5496,10 @@ struct MMA_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + + +using SM90_64x224x32_S32S8S8_SS_TN = SM90::GMMA::MMA_64x224x32_S32S8S8_SS_TN; + template <> struct MMA_Traits { @@ -4189,6 +5524,10 @@ struct MMA_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + + +using SM90_64x224x32_S32S8S8_SS_TN_SATURATE = SM90::GMMA::MMA_64x224x32_S32S8S8_SS_TN_SATURATE; + template <> struct MMA_Traits { @@ -4213,6 +5552,10 @@ 
struct MMA_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + + +using SM90_64x240x32_S32S8S8_SS_TN = SM90::GMMA::MMA_64x240x32_S32S8S8_SS_TN; + template <> struct MMA_Traits { @@ -4237,6 +5580,10 @@ struct MMA_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + + +using SM90_64x240x32_S32S8S8_SS_TN_SATURATE = SM90::GMMA::MMA_64x240x32_S32S8S8_SS_TN_SATURATE; + template <> struct MMA_Traits { @@ -4260,6 +5607,10 @@ struct MMA_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x256x32_S32S8S8_SS_TN = SM90::GMMA::MMA_64x256x32_S32S8S8_SS_TN; + template <> struct MMA_Traits { @@ -4282,6 +5633,10 @@ struct MMA_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x256x32_S32S8S8_SS_TN_SATURATE = SM90::GMMA::MMA_64x256x32_S32S8S8_SS_TN_SATURATE; + template <> struct MMA_Traits { @@ -4304,6 +5659,10 @@ struct MMA_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x8x32_S32S8S8_RS_TN = SM90::GMMA::MMA_64x8x32_S32S8S8_RS_TN; + template <> struct MMA_Traits { @@ -4325,6 +5684,10 @@ struct MMA_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x8x32_S32S8S8_RS_TN_SATURATE = SM90::GMMA::MMA_64x8x32_S32S8S8_RS_TN_SATURATE; + template <> struct MMA_Traits { @@ -4346,6 +5709,10 @@ struct MMA_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x16x32_S32S8S8_RS_TN = SM90::GMMA::MMA_64x16x32_S32S8S8_RS_TN; + template <> struct MMA_Traits { @@ -4367,6 +5734,10 @@ struct MMA_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x16x32_S32S8S8_RS_TN_SATURATE = SM90::GMMA::MMA_64x16x32_S32S8S8_RS_TN_SATURATE; + template <> struct MMA_Traits { @@ -4388,6 +5759,10 @@ struct MMA_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x32x32_S32S8S8_RS_TN = SM90::GMMA::MMA_64x32x32_S32S8S8_RS_TN; + template <> struct MMA_Traits { @@ -4409,6 +5784,10 @@ struct MMA_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x32x32_S32S8S8_RS_TN_SATURATE = SM90::GMMA::MMA_64x32x32_S32S8S8_RS_TN_SATURATE; + template <> struct MMA_Traits { @@ -4431,6 +5810,10 @@ struct MMA_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + + +using SM90_64x48x32_S32S8S8_RS_TN = SM90::GMMA::MMA_64x48x32_S32S8S8_RS_TN; + template <> struct MMA_Traits { @@ -4454,6 +5837,10 @@ struct MMA_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + + +using SM90_64x48x32_S32S8S8_RS_TN_SATURATE = SM90::GMMA::MMA_64x48x32_S32S8S8_RS_TN_SATURATE; + template <> struct MMA_Traits { @@ -4476,6 +5863,10 @@ struct MMA_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// + + +using 
SM90_64x64x32_S32S8S8_RS_TN = SM90::GMMA::MMA_64x64x32_S32S8S8_RS_TN; + template <> struct MMA_Traits { @@ -4497,6 +5888,10 @@ struct MMA_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x64x32_S32S8S8_RS_TN_SATURATE = SM90::GMMA::MMA_64x64x32_S32S8S8_RS_TN_SATURATE; + template <> struct MMA_Traits { @@ -4519,6 +5914,10 @@ struct MMA_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + + +using SM90_64x80x32_S32S8S8_RS_TN = SM90::GMMA::MMA_64x80x32_S32S8S8_RS_TN; + template <> struct MMA_Traits { @@ -4542,6 +5941,10 @@ struct MMA_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + + +using SM90_64x80x32_S32S8S8_RS_TN_SATURATE = SM90::GMMA::MMA_64x80x32_S32S8S8_RS_TN_SATURATE; + template <> struct MMA_Traits { @@ -4564,6 +5967,10 @@ struct MMA_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x96x32_S32S8S8_RS_TN = SM90::GMMA::MMA_64x96x32_S32S8S8_RS_TN; + template <> struct MMA_Traits { @@ -4585,6 +5992,10 @@ struct MMA_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x96x32_S32S8S8_RS_TN_SATURATE = SM90::GMMA::MMA_64x96x32_S32S8S8_RS_TN_SATURATE; + template <> struct MMA_Traits { @@ -4607,6 +6018,10 @@ struct MMA_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + + +using SM90_64x112x32_S32S8S8_RS_TN = SM90::GMMA::MMA_64x112x32_S32S8S8_RS_TN; + template <> struct MMA_Traits { @@ -4630,6 +6045,10 @@ struct MMA_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + + +using SM90_64x112x32_S32S8S8_RS_TN_SATURATE = SM90::GMMA::MMA_64x112x32_S32S8S8_RS_TN_SATURATE; + template <> struct MMA_Traits { @@ -4652,6 +6071,10 @@ struct MMA_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x128x32_S32S8S8_RS_TN = SM90::GMMA::MMA_64x128x32_S32S8S8_RS_TN; + template <> struct MMA_Traits { @@ -4673,6 +6096,10 @@ struct MMA_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x128x32_S32S8S8_RS_TN_SATURATE = SM90::GMMA::MMA_64x128x32_S32S8S8_RS_TN_SATURATE; + template <> struct MMA_Traits { @@ -4695,6 +6122,10 @@ struct MMA_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + + +using SM90_64x144x32_S32S8S8_RS_TN = SM90::GMMA::MMA_64x144x32_S32S8S8_RS_TN; + template <> struct MMA_Traits { @@ -4718,6 +6149,10 @@ struct MMA_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + + +using SM90_64x144x32_S32S8S8_RS_TN_SATURATE = SM90::GMMA::MMA_64x144x32_S32S8S8_RS_TN_SATURATE; + template <> struct MMA_Traits { @@ -4741,6 +6176,10 @@ struct MMA_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + + 
+using SM90_64x160x32_S32S8S8_RS_TN = SM90::GMMA::MMA_64x160x32_S32S8S8_RS_TN; + template <> struct MMA_Traits { @@ -4764,6 +6203,10 @@ struct MMA_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + + +using SM90_64x160x32_S32S8S8_RS_TN_SATURATE = SM90::GMMA::MMA_64x160x32_S32S8S8_RS_TN_SATURATE; + template <> struct MMA_Traits { @@ -4787,6 +6230,10 @@ struct MMA_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + + +using SM90_64x176x32_S32S8S8_RS_TN = SM90::GMMA::MMA_64x176x32_S32S8S8_RS_TN; + template <> struct MMA_Traits { @@ -4810,6 +6257,10 @@ struct MMA_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + + +using SM90_64x176x32_S32S8S8_RS_TN_SATURATE = SM90::GMMA::MMA_64x176x32_S32S8S8_RS_TN_SATURATE; + template <> struct MMA_Traits { @@ -4832,6 +6283,10 @@ struct MMA_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x192x32_S32S8S8_RS_TN = SM90::GMMA::MMA_64x192x32_S32S8S8_RS_TN; + template <> struct MMA_Traits { @@ -4853,6 +6308,10 @@ struct MMA_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x192x32_S32S8S8_RS_TN_SATURATE = SM90::GMMA::MMA_64x192x32_S32S8S8_RS_TN_SATURATE; + template <> struct MMA_Traits { @@ -4875,6 +6334,10 @@ struct MMA_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + + +using SM90_64x208x32_S32S8S8_RS_TN = SM90::GMMA::MMA_64x208x32_S32S8S8_RS_TN; + template <> struct MMA_Traits { @@ -4898,6 +6361,10 @@ struct MMA_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + + +using SM90_64x208x32_S32S8S8_RS_TN_SATURATE = SM90::GMMA::MMA_64x208x32_S32S8S8_RS_TN_SATURATE; + template <> struct MMA_Traits { @@ -4921,6 +6388,10 @@ struct MMA_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + + +using SM90_64x224x32_S32S8S8_RS_TN = SM90::GMMA::MMA_64x224x32_S32S8S8_RS_TN; + template <> struct MMA_Traits { @@ -4944,6 +6415,10 @@ struct MMA_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + + +using SM90_64x224x32_S32S8S8_RS_TN_SATURATE = SM90::GMMA::MMA_64x224x32_S32S8S8_RS_TN_SATURATE; + template <> struct MMA_Traits { @@ -4967,6 +6442,10 @@ struct MMA_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + + +using SM90_64x240x32_S32S8S8_RS_TN = SM90::GMMA::MMA_64x240x32_S32S8S8_RS_TN; + template <> struct MMA_Traits { @@ -4990,6 +6469,10 @@ struct MMA_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + + +using SM90_64x240x32_S32S8S8_RS_TN_SATURATE = SM90::GMMA::MMA_64x240x32_S32S8S8_RS_TN_SATURATE; + template <> struct MMA_Traits { @@ -5012,6 +6495,10 @@ 
struct MMA_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x256x32_S32S8S8_RS_TN = SM90::GMMA::MMA_64x256x32_S32S8S8_RS_TN; + template <> struct MMA_Traits { @@ -5033,6 +6520,10 @@ struct MMA_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x256x32_S32S8S8_RS_TN_SATURATE = SM90::GMMA::MMA_64x256x32_S32S8S8_RS_TN_SATURATE; + template <> struct MMA_Traits { @@ -5054,6 +6545,10 @@ struct MMA_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x8x32_S32S8U8_SS_TN = SM90::GMMA::MMA_64x8x32_S32S8U8_SS_TN; + template <> struct MMA_Traits { @@ -5076,6 +6571,10 @@ struct MMA_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x8x32_S32S8U8_SS_TN_SATURATE = SM90::GMMA::MMA_64x8x32_S32S8U8_SS_TN_SATURATE; + template <> struct MMA_Traits { @@ -5098,6 +6597,10 @@ struct MMA_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x16x32_S32S8U8_SS_TN = SM90::GMMA::MMA_64x16x32_S32S8U8_SS_TN; + template <> struct MMA_Traits { @@ -5120,6 +6623,10 @@ struct MMA_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x16x32_S32S8U8_SS_TN_SATURATE = SM90::GMMA::MMA_64x16x32_S32S8U8_SS_TN_SATURATE; + template <> struct MMA_Traits { @@ -5142,6 +6649,10 @@ struct MMA_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x32x32_S32S8U8_SS_TN = SM90::GMMA::MMA_64x32x32_S32S8U8_SS_TN; + template <> struct MMA_Traits { @@ -5164,6 +6675,10 @@ struct MMA_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x32x32_S32S8U8_SS_TN_SATURATE = SM90::GMMA::MMA_64x32x32_S32S8U8_SS_TN_SATURATE; + template <> struct MMA_Traits { @@ -5187,6 +6702,10 @@ struct MMA_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + + +using SM90_64x48x32_S32S8U8_SS_TN = SM90::GMMA::MMA_64x48x32_S32S8U8_SS_TN; + template <> struct MMA_Traits { @@ -5211,6 +6730,10 @@ struct MMA_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + + +using SM90_64x48x32_S32S8U8_SS_TN_SATURATE = SM90::GMMA::MMA_64x48x32_S32S8U8_SS_TN_SATURATE; + template <> struct MMA_Traits { @@ -5234,7 +6757,11 @@ struct MMA_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// -template <> + + +using SM90_64x64x32_S32S8U8_SS_TN = SM90::GMMA::MMA_64x64x32_S32S8U8_SS_TN; + +template <> struct MMA_Traits { using ValTypeD = int32_t; @@ -5256,6 +6783,10 @@ struct MMA_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x64x32_S32S8U8_SS_TN_SATURATE = SM90::GMMA::MMA_64x64x32_S32S8U8_SS_TN_SATURATE; + template <> struct MMA_Traits { @@ -5279,6 +6810,10 @@ struct MMA_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + + +using SM90_64x80x32_S32S8U8_SS_TN = 
SM90::GMMA::MMA_64x80x32_S32S8U8_SS_TN; + template <> struct MMA_Traits { @@ -5303,6 +6838,10 @@ struct MMA_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + + +using SM90_64x80x32_S32S8U8_SS_TN_SATURATE = SM90::GMMA::MMA_64x80x32_S32S8U8_SS_TN_SATURATE; + template <> struct MMA_Traits { @@ -5326,6 +6865,10 @@ struct MMA_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x96x32_S32S8U8_SS_TN = SM90::GMMA::MMA_64x96x32_S32S8U8_SS_TN; + template <> struct MMA_Traits { @@ -5348,6 +6891,10 @@ struct MMA_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x96x32_S32S8U8_SS_TN_SATURATE = SM90::GMMA::MMA_64x96x32_S32S8U8_SS_TN_SATURATE; + template <> struct MMA_Traits { @@ -5371,6 +6918,10 @@ struct MMA_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + + +using SM90_64x112x32_S32S8U8_SS_TN = SM90::GMMA::MMA_64x112x32_S32S8U8_SS_TN; + template <> struct MMA_Traits { @@ -5395,6 +6946,10 @@ struct MMA_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + + +using SM90_64x112x32_S32S8U8_SS_TN_SATURATE = SM90::GMMA::MMA_64x112x32_S32S8U8_SS_TN_SATURATE; + template <> struct MMA_Traits { @@ -5418,6 +6973,10 @@ struct MMA_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x128x32_S32S8U8_SS_TN = SM90::GMMA::MMA_64x128x32_S32S8U8_SS_TN; + template <> struct MMA_Traits { @@ -5440,6 +6999,10 @@ struct MMA_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x128x32_S32S8U8_SS_TN_SATURATE = SM90::GMMA::MMA_64x128x32_S32S8U8_SS_TN_SATURATE; + template <> struct MMA_Traits { @@ -5463,6 +7026,10 @@ struct MMA_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + + +using SM90_64x144x32_S32S8U8_SS_TN = SM90::GMMA::MMA_64x144x32_S32S8U8_SS_TN; + template <> struct MMA_Traits { @@ -5487,6 +7054,10 @@ struct MMA_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + + +using SM90_64x144x32_S32S8U8_SS_TN_SATURATE = SM90::GMMA::MMA_64x144x32_S32S8U8_SS_TN_SATURATE; + template <> struct MMA_Traits { @@ -5511,6 +7082,10 @@ struct MMA_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + + +using SM90_64x160x32_S32S8U8_SS_TN = SM90::GMMA::MMA_64x160x32_S32S8U8_SS_TN; + template <> struct MMA_Traits { @@ -5535,6 +7110,10 @@ struct MMA_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + + +using SM90_64x160x32_S32S8U8_SS_TN_SATURATE = SM90::GMMA::MMA_64x160x32_S32S8U8_SS_TN_SATURATE; + template <> struct MMA_Traits { @@ -5559,6 +7138,10 @@ struct MMA_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// #if 
defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + + +using SM90_64x176x32_S32S8U8_SS_TN = SM90::GMMA::MMA_64x176x32_S32S8U8_SS_TN; + template <> struct MMA_Traits { @@ -5583,6 +7166,10 @@ struct MMA_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + + +using SM90_64x176x32_S32S8U8_SS_TN_SATURATE = SM90::GMMA::MMA_64x176x32_S32S8U8_SS_TN_SATURATE; + template <> struct MMA_Traits { @@ -5606,6 +7193,10 @@ struct MMA_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x192x32_S32S8U8_SS_TN = SM90::GMMA::MMA_64x192x32_S32S8U8_SS_TN; + template <> struct MMA_Traits { @@ -5628,6 +7219,10 @@ struct MMA_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x192x32_S32S8U8_SS_TN_SATURATE = SM90::GMMA::MMA_64x192x32_S32S8U8_SS_TN_SATURATE; + template <> struct MMA_Traits { @@ -5651,6 +7246,10 @@ struct MMA_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + + +using SM90_64x208x32_S32S8U8_SS_TN = SM90::GMMA::MMA_64x208x32_S32S8U8_SS_TN; + template <> struct MMA_Traits { @@ -5675,6 +7274,10 @@ struct MMA_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + + +using SM90_64x208x32_S32S8U8_SS_TN_SATURATE = SM90::GMMA::MMA_64x208x32_S32S8U8_SS_TN_SATURATE; + template <> struct MMA_Traits { @@ -5699,6 +7302,10 @@ struct MMA_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + + +using SM90_64x224x32_S32S8U8_SS_TN = SM90::GMMA::MMA_64x224x32_S32S8U8_SS_TN; + template <> struct MMA_Traits { @@ -5723,6 +7330,10 @@ struct MMA_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + + +using SM90_64x224x32_S32S8U8_SS_TN_SATURATE = SM90::GMMA::MMA_64x224x32_S32S8U8_SS_TN_SATURATE; + template <> struct MMA_Traits { @@ -5747,6 +7358,10 @@ struct MMA_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + + +using SM90_64x240x32_S32S8U8_SS_TN = SM90::GMMA::MMA_64x240x32_S32S8U8_SS_TN; + template <> struct MMA_Traits { @@ -5771,6 +7386,10 @@ struct MMA_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + + +using SM90_64x240x32_S32S8U8_SS_TN_SATURATE = SM90::GMMA::MMA_64x240x32_S32S8U8_SS_TN_SATURATE; + template <> struct MMA_Traits { @@ -5794,6 +7413,10 @@ struct MMA_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x256x32_S32S8U8_SS_TN = SM90::GMMA::MMA_64x256x32_S32S8U8_SS_TN; + template <> struct MMA_Traits { @@ -5816,6 +7439,10 @@ struct MMA_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x256x32_S32S8U8_SS_TN_SATURATE = SM90::GMMA::MMA_64x256x32_S32S8U8_SS_TN_SATURATE; + template <> struct MMA_Traits { @@ -5838,6 +7465,10 @@ struct MMA_Traits 
//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x8x32_S32S8U8_RS_TN = SM90::GMMA::MMA_64x8x32_S32S8U8_RS_TN; + template <> struct MMA_Traits { @@ -5859,6 +7490,10 @@ struct MMA_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x8x32_S32S8U8_RS_TN_SATURATE = SM90::GMMA::MMA_64x8x32_S32S8U8_RS_TN_SATURATE; + template <> struct MMA_Traits { @@ -5880,6 +7515,10 @@ struct MMA_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x16x32_S32S8U8_RS_TN = SM90::GMMA::MMA_64x16x32_S32S8U8_RS_TN; + template <> struct MMA_Traits { @@ -5901,6 +7540,10 @@ struct MMA_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x16x32_S32S8U8_RS_TN_SATURATE = SM90::GMMA::MMA_64x16x32_S32S8U8_RS_TN_SATURATE; + template <> struct MMA_Traits { @@ -5922,6 +7565,10 @@ struct MMA_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x32x32_S32S8U8_RS_TN = SM90::GMMA::MMA_64x32x32_S32S8U8_RS_TN; + template <> struct MMA_Traits { @@ -5943,6 +7590,10 @@ struct MMA_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x32x32_S32S8U8_RS_TN_SATURATE = SM90::GMMA::MMA_64x32x32_S32S8U8_RS_TN_SATURATE; + template <> struct MMA_Traits { @@ -5965,6 +7616,10 @@ struct MMA_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + + +using SM90_64x48x32_S32S8U8_RS_TN = SM90::GMMA::MMA_64x48x32_S32S8U8_RS_TN; + template <> struct MMA_Traits { @@ -5988,6 +7643,10 @@ struct MMA_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + + +using SM90_64x48x32_S32S8U8_RS_TN_SATURATE = SM90::GMMA::MMA_64x48x32_S32S8U8_RS_TN_SATURATE; + template <> struct MMA_Traits { @@ -6010,6 +7669,10 @@ struct MMA_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x64x32_S32S8U8_RS_TN = SM90::GMMA::MMA_64x64x32_S32S8U8_RS_TN; + template <> struct MMA_Traits { @@ -6031,6 +7694,10 @@ struct MMA_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x64x32_S32S8U8_RS_TN_SATURATE = SM90::GMMA::MMA_64x64x32_S32S8U8_RS_TN_SATURATE; + template <> struct MMA_Traits { @@ -6053,6 +7720,10 @@ struct MMA_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + + +using SM90_64x80x32_S32S8U8_RS_TN = SM90::GMMA::MMA_64x80x32_S32S8U8_RS_TN; + template <> struct MMA_Traits { @@ -6076,6 +7747,10 @@ struct MMA_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + + +using SM90_64x80x32_S32S8U8_RS_TN_SATURATE = SM90::GMMA::MMA_64x80x32_S32S8U8_RS_TN_SATURATE; + template <> struct MMA_Traits { @@ -6098,6 +7773,10 @@ struct MMA_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x96x32_S32S8U8_RS_TN = 
SM90::GMMA::MMA_64x96x32_S32S8U8_RS_TN; + template <> struct MMA_Traits { @@ -6119,6 +7798,10 @@ struct MMA_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x96x32_S32S8U8_RS_TN_SATURATE = SM90::GMMA::MMA_64x96x32_S32S8U8_RS_TN_SATURATE; + template <> struct MMA_Traits { @@ -6141,6 +7824,10 @@ struct MMA_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + + +using SM90_64x112x32_S32S8U8_RS_TN = SM90::GMMA::MMA_64x112x32_S32S8U8_RS_TN; + template <> struct MMA_Traits { @@ -6164,6 +7851,10 @@ struct MMA_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + + +using SM90_64x112x32_S32S8U8_RS_TN_SATURATE = SM90::GMMA::MMA_64x112x32_S32S8U8_RS_TN_SATURATE; + template <> struct MMA_Traits { @@ -6186,6 +7877,10 @@ struct MMA_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x128x32_S32S8U8_RS_TN = SM90::GMMA::MMA_64x128x32_S32S8U8_RS_TN; + template <> struct MMA_Traits { @@ -6207,6 +7902,10 @@ struct MMA_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x128x32_S32S8U8_RS_TN_SATURATE = SM90::GMMA::MMA_64x128x32_S32S8U8_RS_TN_SATURATE; + template <> struct MMA_Traits { @@ -6229,6 +7928,10 @@ struct MMA_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + + +using SM90_64x144x32_S32S8U8_RS_TN = SM90::GMMA::MMA_64x144x32_S32S8U8_RS_TN; + template <> struct MMA_Traits { @@ -6252,6 +7955,10 @@ struct MMA_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + + +using SM90_64x144x32_S32S8U8_RS_TN_SATURATE = SM90::GMMA::MMA_64x144x32_S32S8U8_RS_TN_SATURATE; + template <> struct MMA_Traits { @@ -6275,6 +7982,10 @@ struct MMA_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + + +using SM90_64x160x32_S32S8U8_RS_TN = SM90::GMMA::MMA_64x160x32_S32S8U8_RS_TN; + template <> struct MMA_Traits { @@ -6298,6 +8009,10 @@ struct MMA_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + + +using SM90_64x160x32_S32S8U8_RS_TN_SATURATE = SM90::GMMA::MMA_64x160x32_S32S8U8_RS_TN_SATURATE; + template <> struct MMA_Traits { @@ -6321,6 +8036,10 @@ struct MMA_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + + +using SM90_64x176x32_S32S8U8_RS_TN = SM90::GMMA::MMA_64x176x32_S32S8U8_RS_TN; + template <> struct MMA_Traits { @@ -6344,6 +8063,10 @@ struct MMA_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + + +using SM90_64x176x32_S32S8U8_RS_TN_SATURATE = SM90::GMMA::MMA_64x176x32_S32S8U8_RS_TN_SATURATE; + template <> struct MMA_Traits { @@ -6366,6 +8089,10 @@ struct MMA_Traits 
//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x192x32_S32S8U8_RS_TN = SM90::GMMA::MMA_64x192x32_S32S8U8_RS_TN; + template <> struct MMA_Traits { @@ -6387,6 +8114,10 @@ struct MMA_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x192x32_S32S8U8_RS_TN_SATURATE = SM90::GMMA::MMA_64x192x32_S32S8U8_RS_TN_SATURATE; + template <> struct MMA_Traits { @@ -6409,6 +8140,10 @@ struct MMA_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + + +using SM90_64x208x32_S32S8U8_RS_TN = SM90::GMMA::MMA_64x208x32_S32S8U8_RS_TN; + template <> struct MMA_Traits { @@ -6432,6 +8167,10 @@ struct MMA_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + + +using SM90_64x208x32_S32S8U8_RS_TN_SATURATE = SM90::GMMA::MMA_64x208x32_S32S8U8_RS_TN_SATURATE; + template <> struct MMA_Traits { @@ -6455,6 +8194,10 @@ struct MMA_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + + +using SM90_64x224x32_S32S8U8_RS_TN = SM90::GMMA::MMA_64x224x32_S32S8U8_RS_TN; + template <> struct MMA_Traits { @@ -6478,6 +8221,10 @@ struct MMA_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + + +using SM90_64x224x32_S32S8U8_RS_TN_SATURATE = SM90::GMMA::MMA_64x224x32_S32S8U8_RS_TN_SATURATE; + template <> struct MMA_Traits { @@ -6501,6 +8248,10 @@ struct MMA_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + + +using SM90_64x240x32_S32S8U8_RS_TN = SM90::GMMA::MMA_64x240x32_S32S8U8_RS_TN; + template <> struct MMA_Traits { @@ -6524,6 +8275,10 @@ struct MMA_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + + +using SM90_64x240x32_S32S8U8_RS_TN_SATURATE = SM90::GMMA::MMA_64x240x32_S32S8U8_RS_TN_SATURATE; + template <> struct MMA_Traits { @@ -6546,6 +8301,10 @@ struct MMA_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x256x32_S32S8U8_RS_TN = SM90::GMMA::MMA_64x256x32_S32S8U8_RS_TN; + template <> struct MMA_Traits { @@ -6567,6 +8326,10 @@ struct MMA_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x256x32_S32S8U8_RS_TN_SATURATE = SM90::GMMA::MMA_64x256x32_S32S8U8_RS_TN_SATURATE; + template <> struct MMA_Traits { @@ -6588,6 +8351,10 @@ struct MMA_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x8x32_S32U8S8_SS_TN = SM90::GMMA::MMA_64x8x32_S32U8S8_SS_TN; + template <> struct MMA_Traits { @@ -6610,6 +8377,10 @@ struct MMA_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x8x32_S32U8S8_SS_TN_SATURATE = SM90::GMMA::MMA_64x8x32_S32U8S8_SS_TN_SATURATE; + template <> struct MMA_Traits { @@ -6632,6 +8403,10 @@ struct MMA_Traits 
//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x16x32_S32U8S8_SS_TN = SM90::GMMA::MMA_64x16x32_S32U8S8_SS_TN; + template <> struct MMA_Traits { @@ -6654,6 +8429,10 @@ struct MMA_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x16x32_S32U8S8_SS_TN_SATURATE = SM90::GMMA::MMA_64x16x32_S32U8S8_SS_TN_SATURATE; + template <> struct MMA_Traits { @@ -6676,6 +8455,10 @@ struct MMA_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x32x32_S32U8S8_SS_TN = SM90::GMMA::MMA_64x32x32_S32U8S8_SS_TN; + template <> struct MMA_Traits { @@ -6698,6 +8481,10 @@ struct MMA_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x32x32_S32U8S8_SS_TN_SATURATE = SM90::GMMA::MMA_64x32x32_S32U8S8_SS_TN_SATURATE; + template <> struct MMA_Traits { @@ -6721,6 +8508,10 @@ struct MMA_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + + +using SM90_64x48x32_S32U8S8_SS_TN = SM90::GMMA::MMA_64x48x32_S32U8S8_SS_TN; + template <> struct MMA_Traits { @@ -6745,6 +8536,10 @@ struct MMA_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + + +using SM90_64x48x32_S32U8S8_SS_TN_SATURATE = SM90::GMMA::MMA_64x48x32_S32U8S8_SS_TN_SATURATE; + template <> struct MMA_Traits { @@ -6768,6 +8563,10 @@ struct MMA_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x64x32_S32U8S8_SS_TN = SM90::GMMA::MMA_64x64x32_S32U8S8_SS_TN; + template <> struct MMA_Traits { @@ -6790,6 +8589,10 @@ struct MMA_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x64x32_S32U8S8_SS_TN_SATURATE = SM90::GMMA::MMA_64x64x32_S32U8S8_SS_TN_SATURATE; + template <> struct MMA_Traits { @@ -6813,6 +8616,10 @@ struct MMA_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + + +using SM90_64x80x32_S32U8S8_SS_TN = SM90::GMMA::MMA_64x80x32_S32U8S8_SS_TN; + template <> struct MMA_Traits { @@ -6837,6 +8644,10 @@ struct MMA_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + + +using SM90_64x80x32_S32U8S8_SS_TN_SATURATE = SM90::GMMA::MMA_64x80x32_S32U8S8_SS_TN_SATURATE; + template <> struct MMA_Traits { @@ -6860,6 +8671,10 @@ struct MMA_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x96x32_S32U8S8_SS_TN = SM90::GMMA::MMA_64x96x32_S32U8S8_SS_TN; + template <> struct MMA_Traits { @@ -6882,6 +8697,10 @@ struct MMA_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x96x32_S32U8S8_SS_TN_SATURATE = SM90::GMMA::MMA_64x96x32_S32U8S8_SS_TN_SATURATE; + template <> struct MMA_Traits { @@ -6905,6 +8724,10 @@ struct MMA_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + + 
+using SM90_64x112x32_S32U8S8_SS_TN = SM90::GMMA::MMA_64x112x32_S32U8S8_SS_TN; + template <> struct MMA_Traits { @@ -6929,6 +8752,10 @@ struct MMA_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + + +using SM90_64x112x32_S32U8S8_SS_TN_SATURATE = SM90::GMMA::MMA_64x112x32_S32U8S8_SS_TN_SATURATE; + template <> struct MMA_Traits { @@ -6952,6 +8779,10 @@ struct MMA_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x128x32_S32U8S8_SS_TN = SM90::GMMA::MMA_64x128x32_S32U8S8_SS_TN; + template <> struct MMA_Traits { @@ -6974,6 +8805,10 @@ struct MMA_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x128x32_S32U8S8_SS_TN_SATURATE = SM90::GMMA::MMA_64x128x32_S32U8S8_SS_TN_SATURATE; + template <> struct MMA_Traits { @@ -6997,6 +8832,10 @@ struct MMA_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + + +using SM90_64x144x32_S32U8S8_SS_TN = SM90::GMMA::MMA_64x144x32_S32U8S8_SS_TN; + template <> struct MMA_Traits { @@ -7021,6 +8860,10 @@ struct MMA_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + + +using SM90_64x144x32_S32U8S8_SS_TN_SATURATE = SM90::GMMA::MMA_64x144x32_S32U8S8_SS_TN_SATURATE; + template <> struct MMA_Traits { @@ -7045,6 +8888,10 @@ struct MMA_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + + +using SM90_64x160x32_S32U8S8_SS_TN = SM90::GMMA::MMA_64x160x32_S32U8S8_SS_TN; + template <> struct MMA_Traits { @@ -7069,6 +8916,10 @@ struct MMA_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + + +using SM90_64x160x32_S32U8S8_SS_TN_SATURATE = SM90::GMMA::MMA_64x160x32_S32U8S8_SS_TN_SATURATE; + template <> struct MMA_Traits { @@ -7093,6 +8944,10 @@ struct MMA_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + + +using SM90_64x176x32_S32U8S8_SS_TN = SM90::GMMA::MMA_64x176x32_S32U8S8_SS_TN; + template <> struct MMA_Traits { @@ -7117,6 +8972,10 @@ struct MMA_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + + +using SM90_64x176x32_S32U8S8_SS_TN_SATURATE = SM90::GMMA::MMA_64x176x32_S32U8S8_SS_TN_SATURATE; + template <> struct MMA_Traits { @@ -7140,6 +8999,10 @@ struct MMA_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x192x32_S32U8S8_SS_TN = SM90::GMMA::MMA_64x192x32_S32U8S8_SS_TN; + template <> struct MMA_Traits { @@ -7162,6 +9025,10 @@ struct MMA_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x192x32_S32U8S8_SS_TN_SATURATE = SM90::GMMA::MMA_64x192x32_S32U8S8_SS_TN_SATURATE; + template <> struct MMA_Traits { @@ -7185,6 +9052,10 @@ struct MMA_Traits 
//////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + + +using SM90_64x208x32_S32U8S8_SS_TN = SM90::GMMA::MMA_64x208x32_S32U8S8_SS_TN; + template <> struct MMA_Traits { @@ -7209,6 +9080,10 @@ struct MMA_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + + +using SM90_64x208x32_S32U8S8_SS_TN_SATURATE = SM90::GMMA::MMA_64x208x32_S32U8S8_SS_TN_SATURATE; + template <> struct MMA_Traits { @@ -7233,6 +9108,10 @@ struct MMA_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + + +using SM90_64x224x32_S32U8S8_SS_TN = SM90::GMMA::MMA_64x224x32_S32U8S8_SS_TN; + template <> struct MMA_Traits { @@ -7257,6 +9136,10 @@ struct MMA_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + + +using SM90_64x224x32_S32U8S8_SS_TN_SATURATE = SM90::GMMA::MMA_64x224x32_S32U8S8_SS_TN_SATURATE; + template <> struct MMA_Traits { @@ -7281,6 +9164,10 @@ struct MMA_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + + +using SM90_64x240x32_S32U8S8_SS_TN = SM90::GMMA::MMA_64x240x32_S32U8S8_SS_TN; + template <> struct MMA_Traits { @@ -7305,6 +9192,10 @@ struct MMA_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + + +using SM90_64x240x32_S32U8S8_SS_TN_SATURATE = SM90::GMMA::MMA_64x240x32_S32U8S8_SS_TN_SATURATE; + template <> struct MMA_Traits { @@ -7328,6 +9219,10 @@ struct MMA_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x256x32_S32U8S8_SS_TN = SM90::GMMA::MMA_64x256x32_S32U8S8_SS_TN; + template <> struct MMA_Traits { @@ -7350,6 +9245,10 @@ struct MMA_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x256x32_S32U8S8_SS_TN_SATURATE = SM90::GMMA::MMA_64x256x32_S32U8S8_SS_TN_SATURATE; + template <> struct MMA_Traits { @@ -7372,6 +9271,10 @@ struct MMA_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x8x32_S32U8S8_RS_TN = SM90::GMMA::MMA_64x8x32_S32U8S8_RS_TN; + template <> struct MMA_Traits { @@ -7393,6 +9296,10 @@ struct MMA_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x8x32_S32U8S8_RS_TN_SATURATE = SM90::GMMA::MMA_64x8x32_S32U8S8_RS_TN_SATURATE; + template <> struct MMA_Traits { @@ -7414,6 +9321,10 @@ struct MMA_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x16x32_S32U8S8_RS_TN = SM90::GMMA::MMA_64x16x32_S32U8S8_RS_TN; + template <> struct MMA_Traits { @@ -7435,6 +9346,10 @@ struct MMA_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x16x32_S32U8S8_RS_TN_SATURATE = SM90::GMMA::MMA_64x16x32_S32U8S8_RS_TN_SATURATE; + template <> struct MMA_Traits { @@ -7456,6 +9371,10 @@ struct MMA_Traits 
//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x32x32_S32U8S8_RS_TN = SM90::GMMA::MMA_64x32x32_S32U8S8_RS_TN; + template <> struct MMA_Traits { @@ -7477,6 +9396,10 @@ struct MMA_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x32x32_S32U8S8_RS_TN_SATURATE = SM90::GMMA::MMA_64x32x32_S32U8S8_RS_TN_SATURATE; + template <> struct MMA_Traits { @@ -7499,6 +9422,10 @@ struct MMA_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + + +using SM90_64x48x32_S32U8S8_RS_TN = SM90::GMMA::MMA_64x48x32_S32U8S8_RS_TN; + template <> struct MMA_Traits { @@ -7522,6 +9449,10 @@ struct MMA_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + + +using SM90_64x48x32_S32U8S8_RS_TN_SATURATE = SM90::GMMA::MMA_64x48x32_S32U8S8_RS_TN_SATURATE; + template <> struct MMA_Traits { @@ -7544,6 +9475,10 @@ struct MMA_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x64x32_S32U8S8_RS_TN = SM90::GMMA::MMA_64x64x32_S32U8S8_RS_TN; + template <> struct MMA_Traits { @@ -7565,6 +9500,10 @@ struct MMA_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x64x32_S32U8S8_RS_TN_SATURATE = SM90::GMMA::MMA_64x64x32_S32U8S8_RS_TN_SATURATE; + template <> struct MMA_Traits { @@ -7587,6 +9526,10 @@ struct MMA_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + + +using SM90_64x80x32_S32U8S8_RS_TN = SM90::GMMA::MMA_64x80x32_S32U8S8_RS_TN; + template <> struct MMA_Traits { @@ -7610,6 +9553,10 @@ struct MMA_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + + +using SM90_64x80x32_S32U8S8_RS_TN_SATURATE = SM90::GMMA::MMA_64x80x32_S32U8S8_RS_TN_SATURATE; + template <> struct MMA_Traits { @@ -7632,6 +9579,10 @@ struct MMA_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x96x32_S32U8S8_RS_TN = SM90::GMMA::MMA_64x96x32_S32U8S8_RS_TN; + template <> struct MMA_Traits { @@ -7653,6 +9604,10 @@ struct MMA_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x96x32_S32U8S8_RS_TN_SATURATE = SM90::GMMA::MMA_64x96x32_S32U8S8_RS_TN_SATURATE; + template <> struct MMA_Traits { @@ -7675,6 +9630,10 @@ struct MMA_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + + +using SM90_64x112x32_S32U8S8_RS_TN = SM90::GMMA::MMA_64x112x32_S32U8S8_RS_TN; + template <> struct MMA_Traits { @@ -7698,6 +9657,10 @@ struct MMA_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + + +using SM90_64x112x32_S32U8S8_RS_TN_SATURATE = SM90::GMMA::MMA_64x112x32_S32U8S8_RS_TN_SATURATE; + template <> struct MMA_Traits { @@ -7720,6 +9683,10 @@ struct MMA_Traits 
//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x128x32_S32U8S8_RS_TN = SM90::GMMA::MMA_64x128x32_S32U8S8_RS_TN; + template <> struct MMA_Traits { @@ -7741,6 +9708,10 @@ struct MMA_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x128x32_S32U8S8_RS_TN_SATURATE = SM90::GMMA::MMA_64x128x32_S32U8S8_RS_TN_SATURATE; + template <> struct MMA_Traits { @@ -7763,6 +9734,10 @@ struct MMA_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + + +using SM90_64x144x32_S32U8S8_RS_TN = SM90::GMMA::MMA_64x144x32_S32U8S8_RS_TN; + template <> struct MMA_Traits { @@ -7786,6 +9761,10 @@ struct MMA_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + + +using SM90_64x144x32_S32U8S8_RS_TN_SATURATE = SM90::GMMA::MMA_64x144x32_S32U8S8_RS_TN_SATURATE; + template <> struct MMA_Traits { @@ -7809,6 +9788,10 @@ struct MMA_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + + +using SM90_64x160x32_S32U8S8_RS_TN = SM90::GMMA::MMA_64x160x32_S32U8S8_RS_TN; + template <> struct MMA_Traits { @@ -7832,6 +9815,10 @@ struct MMA_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + + +using SM90_64x160x32_S32U8S8_RS_TN_SATURATE = SM90::GMMA::MMA_64x160x32_S32U8S8_RS_TN_SATURATE; + template <> struct MMA_Traits { @@ -7855,6 +9842,10 @@ struct MMA_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + + +using SM90_64x176x32_S32U8S8_RS_TN = SM90::GMMA::MMA_64x176x32_S32U8S8_RS_TN; + template <> struct MMA_Traits { @@ -7878,6 +9869,10 @@ struct MMA_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + + +using SM90_64x176x32_S32U8S8_RS_TN_SATURATE = SM90::GMMA::MMA_64x176x32_S32U8S8_RS_TN_SATURATE; + template <> struct MMA_Traits { @@ -7900,6 +9895,10 @@ struct MMA_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x192x32_S32U8S8_RS_TN = SM90::GMMA::MMA_64x192x32_S32U8S8_RS_TN; + template <> struct MMA_Traits { @@ -7921,6 +9920,10 @@ struct MMA_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x192x32_S32U8S8_RS_TN_SATURATE = SM90::GMMA::MMA_64x192x32_S32U8S8_RS_TN_SATURATE; + template <> struct MMA_Traits { @@ -7943,6 +9946,10 @@ struct MMA_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + + +using SM90_64x208x32_S32U8S8_RS_TN = SM90::GMMA::MMA_64x208x32_S32U8S8_RS_TN; + template <> struct MMA_Traits { @@ -7966,6 +9973,10 @@ struct MMA_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + + +using SM90_64x208x32_S32U8S8_RS_TN_SATURATE = SM90::GMMA::MMA_64x208x32_S32U8S8_RS_TN_SATURATE; + 
template <> struct MMA_Traits { @@ -7989,6 +10000,10 @@ struct MMA_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + + +using SM90_64x224x32_S32U8S8_RS_TN = SM90::GMMA::MMA_64x224x32_S32U8S8_RS_TN; + template <> struct MMA_Traits { @@ -8012,6 +10027,10 @@ struct MMA_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + + +using SM90_64x224x32_S32U8S8_RS_TN_SATURATE = SM90::GMMA::MMA_64x224x32_S32U8S8_RS_TN_SATURATE; + template <> struct MMA_Traits { @@ -8035,6 +10054,10 @@ struct MMA_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + + +using SM90_64x240x32_S32U8S8_RS_TN = SM90::GMMA::MMA_64x240x32_S32U8S8_RS_TN; + template <> struct MMA_Traits { @@ -8058,6 +10081,10 @@ struct MMA_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + + +using SM90_64x240x32_S32U8S8_RS_TN_SATURATE = SM90::GMMA::MMA_64x240x32_S32U8S8_RS_TN_SATURATE; + template <> struct MMA_Traits { @@ -8080,6 +10107,10 @@ struct MMA_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x256x32_S32U8S8_RS_TN = SM90::GMMA::MMA_64x256x32_S32U8S8_RS_TN; + template <> struct MMA_Traits { @@ -8101,6 +10132,10 @@ struct MMA_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x256x32_S32U8S8_RS_TN_SATURATE = SM90::GMMA::MMA_64x256x32_S32U8S8_RS_TN_SATURATE; + template <> struct MMA_Traits { @@ -8122,6 +10157,10 @@ struct MMA_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x8x32_S32U8U8_SS_TN = SM90::GMMA::MMA_64x8x32_S32U8U8_SS_TN; + template <> struct MMA_Traits { @@ -8144,6 +10183,10 @@ struct MMA_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x8x32_S32U8U8_SS_TN_SATURATE = SM90::GMMA::MMA_64x8x32_S32U8U8_SS_TN_SATURATE; + template <> struct MMA_Traits { @@ -8166,6 +10209,10 @@ struct MMA_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x16x32_S32U8U8_SS_TN = SM90::GMMA::MMA_64x16x32_S32U8U8_SS_TN; + template <> struct MMA_Traits { @@ -8188,6 +10235,10 @@ struct MMA_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x16x32_S32U8U8_SS_TN_SATURATE = SM90::GMMA::MMA_64x16x32_S32U8U8_SS_TN_SATURATE; + template <> struct MMA_Traits { @@ -8210,6 +10261,10 @@ struct MMA_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x32x32_S32U8U8_SS_TN = SM90::GMMA::MMA_64x32x32_S32U8U8_SS_TN; + template <> struct MMA_Traits { @@ -8232,6 +10287,10 @@ struct MMA_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x32x32_S32U8U8_SS_TN_SATURATE = SM90::GMMA::MMA_64x32x32_S32U8U8_SS_TN_SATURATE; + template <> struct MMA_Traits { @@ -8255,6 +10314,10 @@ struct MMA_Traits 
//////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + + +using SM90_64x48x32_S32U8U8_SS_TN = SM90::GMMA::MMA_64x48x32_S32U8U8_SS_TN; + template <> struct MMA_Traits { @@ -8279,6 +10342,10 @@ struct MMA_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + + +using SM90_64x48x32_S32U8U8_SS_TN_SATURATE = SM90::GMMA::MMA_64x48x32_S32U8U8_SS_TN_SATURATE; + template <> struct MMA_Traits { @@ -8302,6 +10369,10 @@ struct MMA_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x64x32_S32U8U8_SS_TN = SM90::GMMA::MMA_64x64x32_S32U8U8_SS_TN; + template <> struct MMA_Traits { @@ -8324,6 +10395,10 @@ struct MMA_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x64x32_S32U8U8_SS_TN_SATURATE = SM90::GMMA::MMA_64x64x32_S32U8U8_SS_TN_SATURATE; + template <> struct MMA_Traits { @@ -8347,6 +10422,10 @@ struct MMA_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + + +using SM90_64x80x32_S32U8U8_SS_TN = SM90::GMMA::MMA_64x80x32_S32U8U8_SS_TN; + template <> struct MMA_Traits { @@ -8371,6 +10450,10 @@ struct MMA_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + + +using SM90_64x80x32_S32U8U8_SS_TN_SATURATE = SM90::GMMA::MMA_64x80x32_S32U8U8_SS_TN_SATURATE; + template <> struct MMA_Traits { @@ -8394,6 +10477,10 @@ struct MMA_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x96x32_S32U8U8_SS_TN = SM90::GMMA::MMA_64x96x32_S32U8U8_SS_TN; + template <> struct MMA_Traits { @@ -8416,6 +10503,10 @@ struct MMA_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x96x32_S32U8U8_SS_TN_SATURATE = SM90::GMMA::MMA_64x96x32_S32U8U8_SS_TN_SATURATE; + template <> struct MMA_Traits { @@ -8439,6 +10530,10 @@ struct MMA_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + + +using SM90_64x112x32_S32U8U8_SS_TN = SM90::GMMA::MMA_64x112x32_S32U8U8_SS_TN; + template <> struct MMA_Traits { @@ -8463,6 +10558,10 @@ struct MMA_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + + +using SM90_64x112x32_S32U8U8_SS_TN_SATURATE = SM90::GMMA::MMA_64x112x32_S32U8U8_SS_TN_SATURATE; + template <> struct MMA_Traits { @@ -8486,6 +10585,10 @@ struct MMA_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x128x32_S32U8U8_SS_TN = SM90::GMMA::MMA_64x128x32_S32U8U8_SS_TN; + template <> struct MMA_Traits { @@ -8508,6 +10611,10 @@ struct MMA_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x128x32_S32U8U8_SS_TN_SATURATE = SM90::GMMA::MMA_64x128x32_S32U8U8_SS_TN_SATURATE; + template <> struct MMA_Traits { @@ -8531,6 +10638,10 @@ struct MMA_Traits 
//////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + + +using SM90_64x144x32_S32U8U8_SS_TN = SM90::GMMA::MMA_64x144x32_S32U8U8_SS_TN; + template <> struct MMA_Traits { @@ -8555,6 +10666,10 @@ struct MMA_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + + +using SM90_64x144x32_S32U8U8_SS_TN_SATURATE = SM90::GMMA::MMA_64x144x32_S32U8U8_SS_TN_SATURATE; + template <> struct MMA_Traits { @@ -8579,6 +10694,10 @@ struct MMA_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + + +using SM90_64x160x32_S32U8U8_SS_TN = SM90::GMMA::MMA_64x160x32_S32U8U8_SS_TN; + template <> struct MMA_Traits { @@ -8603,6 +10722,10 @@ struct MMA_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + + +using SM90_64x160x32_S32U8U8_SS_TN_SATURATE = SM90::GMMA::MMA_64x160x32_S32U8U8_SS_TN_SATURATE; + template <> struct MMA_Traits { @@ -8627,6 +10750,10 @@ struct MMA_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + + +using SM90_64x176x32_S32U8U8_SS_TN = SM90::GMMA::MMA_64x176x32_S32U8U8_SS_TN; + template <> struct MMA_Traits { @@ -8651,6 +10778,10 @@ struct MMA_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + + +using SM90_64x176x32_S32U8U8_SS_TN_SATURATE = SM90::GMMA::MMA_64x176x32_S32U8U8_SS_TN_SATURATE; + template <> struct MMA_Traits { @@ -8674,6 +10805,10 @@ struct MMA_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x192x32_S32U8U8_SS_TN = SM90::GMMA::MMA_64x192x32_S32U8U8_SS_TN; + template <> struct MMA_Traits { @@ -8696,6 +10831,10 @@ struct MMA_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x192x32_S32U8U8_SS_TN_SATURATE = SM90::GMMA::MMA_64x192x32_S32U8U8_SS_TN_SATURATE; + template <> struct MMA_Traits { @@ -8719,6 +10858,10 @@ struct MMA_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + + +using SM90_64x208x32_S32U8U8_SS_TN = SM90::GMMA::MMA_64x208x32_S32U8U8_SS_TN; + template <> struct MMA_Traits { @@ -8743,6 +10886,10 @@ struct MMA_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + + +using SM90_64x208x32_S32U8U8_SS_TN_SATURATE = SM90::GMMA::MMA_64x208x32_S32U8U8_SS_TN_SATURATE; + template <> struct MMA_Traits { @@ -8767,6 +10914,10 @@ struct MMA_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + + +using SM90_64x224x32_S32U8U8_SS_TN = SM90::GMMA::MMA_64x224x32_S32U8U8_SS_TN; + template <> struct MMA_Traits { @@ -8791,6 +10942,10 @@ struct MMA_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// #if 
defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + + +using SM90_64x224x32_S32U8U8_SS_TN_SATURATE = SM90::GMMA::MMA_64x224x32_S32U8U8_SS_TN_SATURATE; + template <> struct MMA_Traits { @@ -8815,6 +10970,10 @@ struct MMA_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + + +using SM90_64x240x32_S32U8U8_SS_TN = SM90::GMMA::MMA_64x240x32_S32U8U8_SS_TN; + template <> struct MMA_Traits { @@ -8839,6 +10998,10 @@ struct MMA_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + + +using SM90_64x240x32_S32U8U8_SS_TN_SATURATE = SM90::GMMA::MMA_64x240x32_S32U8U8_SS_TN_SATURATE; + template <> struct MMA_Traits { @@ -8862,6 +11025,10 @@ struct MMA_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x256x32_S32U8U8_SS_TN = SM90::GMMA::MMA_64x256x32_S32U8U8_SS_TN; + template <> struct MMA_Traits { @@ -8884,6 +11051,10 @@ struct MMA_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x256x32_S32U8U8_SS_TN_SATURATE = SM90::GMMA::MMA_64x256x32_S32U8U8_SS_TN_SATURATE; + template <> struct MMA_Traits { @@ -8906,6 +11077,10 @@ struct MMA_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x8x32_S32U8U8_RS_TN = SM90::GMMA::MMA_64x8x32_S32U8U8_RS_TN; + template <> struct MMA_Traits { @@ -8927,6 +11102,10 @@ struct MMA_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x8x32_S32U8U8_RS_TN_SATURATE = SM90::GMMA::MMA_64x8x32_S32U8U8_RS_TN_SATURATE; + template <> struct MMA_Traits { @@ -8948,6 +11127,10 @@ struct MMA_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x16x32_S32U8U8_RS_TN = SM90::GMMA::MMA_64x16x32_S32U8U8_RS_TN; + template <> struct MMA_Traits { @@ -8969,6 +11152,10 @@ struct MMA_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x16x32_S32U8U8_RS_TN_SATURATE = SM90::GMMA::MMA_64x16x32_S32U8U8_RS_TN_SATURATE; + template <> struct MMA_Traits { @@ -8990,6 +11177,10 @@ struct MMA_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x32x32_S32U8U8_RS_TN = SM90::GMMA::MMA_64x32x32_S32U8U8_RS_TN; + template <> struct MMA_Traits { @@ -9011,6 +11202,10 @@ struct MMA_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x32x32_S32U8U8_RS_TN_SATURATE = SM90::GMMA::MMA_64x32x32_S32U8U8_RS_TN_SATURATE; + template <> struct MMA_Traits { @@ -9033,6 +11228,10 @@ struct MMA_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + + +using SM90_64x48x32_S32U8U8_RS_TN = SM90::GMMA::MMA_64x48x32_S32U8U8_RS_TN; + template <> struct MMA_Traits { @@ -9056,6 +11255,10 @@ struct MMA_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + + +using SM90_64x48x32_S32U8U8_RS_TN_SATURATE = 
SM90::GMMA::MMA_64x48x32_S32U8U8_RS_TN_SATURATE; + template <> struct MMA_Traits { @@ -9078,6 +11281,10 @@ struct MMA_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x64x32_S32U8U8_RS_TN = SM90::GMMA::MMA_64x64x32_S32U8U8_RS_TN; + template <> struct MMA_Traits { @@ -9099,6 +11306,10 @@ struct MMA_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x64x32_S32U8U8_RS_TN_SATURATE = SM90::GMMA::MMA_64x64x32_S32U8U8_RS_TN_SATURATE; + template <> struct MMA_Traits { @@ -9121,6 +11332,10 @@ struct MMA_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + + +using SM90_64x80x32_S32U8U8_RS_TN = SM90::GMMA::MMA_64x80x32_S32U8U8_RS_TN; + template <> struct MMA_Traits { @@ -9144,6 +11359,10 @@ struct MMA_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + + +using SM90_64x80x32_S32U8U8_RS_TN_SATURATE = SM90::GMMA::MMA_64x80x32_S32U8U8_RS_TN_SATURATE; + template <> struct MMA_Traits { @@ -9166,6 +11385,10 @@ struct MMA_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x96x32_S32U8U8_RS_TN = SM90::GMMA::MMA_64x96x32_S32U8U8_RS_TN; + template <> struct MMA_Traits { @@ -9187,6 +11410,10 @@ struct MMA_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x96x32_S32U8U8_RS_TN_SATURATE = SM90::GMMA::MMA_64x96x32_S32U8U8_RS_TN_SATURATE; + template <> struct MMA_Traits { @@ -9209,6 +11436,10 @@ struct MMA_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + + +using SM90_64x112x32_S32U8U8_RS_TN = SM90::GMMA::MMA_64x112x32_S32U8U8_RS_TN; + template <> struct MMA_Traits { @@ -9232,6 +11463,10 @@ struct MMA_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + + +using SM90_64x112x32_S32U8U8_RS_TN_SATURATE = SM90::GMMA::MMA_64x112x32_S32U8U8_RS_TN_SATURATE; + template <> struct MMA_Traits { @@ -9254,6 +11489,10 @@ struct MMA_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x128x32_S32U8U8_RS_TN = SM90::GMMA::MMA_64x128x32_S32U8U8_RS_TN; + template <> struct MMA_Traits { @@ -9275,6 +11514,10 @@ struct MMA_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x128x32_S32U8U8_RS_TN_SATURATE = SM90::GMMA::MMA_64x128x32_S32U8U8_RS_TN_SATURATE; + template <> struct MMA_Traits { @@ -9297,6 +11540,10 @@ struct MMA_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + + +using SM90_64x144x32_S32U8U8_RS_TN = SM90::GMMA::MMA_64x144x32_S32U8U8_RS_TN; + template <> struct MMA_Traits { @@ -9320,6 +11567,10 @@ struct MMA_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + + +using SM90_64x144x32_S32U8U8_RS_TN_SATURATE = 
SM90::GMMA::MMA_64x144x32_S32U8U8_RS_TN_SATURATE; + template <> struct MMA_Traits { @@ -9343,6 +11594,10 @@ struct MMA_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + + +using SM90_64x160x32_S32U8U8_RS_TN = SM90::GMMA::MMA_64x160x32_S32U8U8_RS_TN; + template <> struct MMA_Traits { @@ -9366,6 +11621,10 @@ struct MMA_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + + +using SM90_64x160x32_S32U8U8_RS_TN_SATURATE = SM90::GMMA::MMA_64x160x32_S32U8U8_RS_TN_SATURATE; + template <> struct MMA_Traits { @@ -9389,6 +11648,10 @@ struct MMA_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + + +using SM90_64x176x32_S32U8U8_RS_TN = SM90::GMMA::MMA_64x176x32_S32U8U8_RS_TN; + template <> struct MMA_Traits { @@ -9412,6 +11675,10 @@ struct MMA_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + + +using SM90_64x176x32_S32U8U8_RS_TN_SATURATE = SM90::GMMA::MMA_64x176x32_S32U8U8_RS_TN_SATURATE; + template <> struct MMA_Traits { @@ -9434,6 +11701,10 @@ struct MMA_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x192x32_S32U8U8_RS_TN = SM90::GMMA::MMA_64x192x32_S32U8U8_RS_TN; + template <> struct MMA_Traits { @@ -9455,6 +11726,10 @@ struct MMA_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x192x32_S32U8U8_RS_TN_SATURATE = SM90::GMMA::MMA_64x192x32_S32U8U8_RS_TN_SATURATE; + template <> struct MMA_Traits { @@ -9477,6 +11752,10 @@ struct MMA_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + + +using SM90_64x208x32_S32U8U8_RS_TN = SM90::GMMA::MMA_64x208x32_S32U8U8_RS_TN; + template <> struct MMA_Traits { @@ -9500,6 +11779,10 @@ struct MMA_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + + +using SM90_64x208x32_S32U8U8_RS_TN_SATURATE = SM90::GMMA::MMA_64x208x32_S32U8U8_RS_TN_SATURATE; + template <> struct MMA_Traits { @@ -9523,6 +11806,10 @@ struct MMA_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + + +using SM90_64x224x32_S32U8U8_RS_TN = SM90::GMMA::MMA_64x224x32_S32U8U8_RS_TN; + template <> struct MMA_Traits { @@ -9546,6 +11833,10 @@ struct MMA_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + + +using SM90_64x224x32_S32U8U8_RS_TN_SATURATE = SM90::GMMA::MMA_64x224x32_S32U8U8_RS_TN_SATURATE; + template <> struct MMA_Traits { @@ -9569,6 +11860,10 @@ struct MMA_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + + +using SM90_64x240x32_S32U8U8_RS_TN = SM90::GMMA::MMA_64x240x32_S32U8U8_RS_TN; + template <> struct MMA_Traits { @@ -9592,6 +11887,10 @@ struct MMA_Traits 
//////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + + +using SM90_64x240x32_S32U8U8_RS_TN_SATURATE = SM90::GMMA::MMA_64x240x32_S32U8U8_RS_TN_SATURATE; + template <> struct MMA_Traits { @@ -9614,6 +11913,10 @@ struct MMA_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x256x32_S32U8U8_RS_TN = SM90::GMMA::MMA_64x256x32_S32U8U8_RS_TN; + template <> struct MMA_Traits { @@ -9635,6 +11938,10 @@ struct MMA_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x256x32_S32U8U8_RS_TN_SATURATE = SM90::GMMA::MMA_64x256x32_S32U8U8_RS_TN_SATURATE; + template <> struct MMA_Traits { @@ -9656,6 +11963,13 @@ struct MMA_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x8x32_F16E4M3E4M3_SS_TN = SM90::GMMA::MMA_64x8x32_F16E4M3E4M3_SS_TN; + template struct MMA_Traits> { @@ -9678,6 +11992,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x8x32_F16E4M3E4M3_RS_TN = SM90::GMMA::MMA_64x8x32_F16E4M3E4M3_RS_TN; + template struct MMA_Traits> { @@ -9699,6 +12020,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x8x32_F32E4M3E4M3_SS_TN = SM90::GMMA::MMA_64x8x32_F32E4M3E4M3_SS_TN; + template struct MMA_Traits> { @@ -9721,6 +12049,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x8x32_F32E4M3E4M3_RS_TN = SM90::GMMA::MMA_64x8x32_F32E4M3E4M3_RS_TN; + template struct MMA_Traits> { @@ -9742,6 +12077,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x16x32_F16E4M3E4M3_SS_TN = SM90::GMMA::MMA_64x16x32_F16E4M3E4M3_SS_TN; + template struct MMA_Traits> { @@ -9764,6 +12106,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x16x32_F16E4M3E4M3_RS_TN = SM90::GMMA::MMA_64x16x32_F16E4M3E4M3_RS_TN; + template struct MMA_Traits> { @@ -9785,6 +12134,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x16x32_F32E4M3E4M3_SS_TN = SM90::GMMA::MMA_64x16x32_F32E4M3E4M3_SS_TN; + template struct MMA_Traits> { @@ -9807,6 +12163,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + 
GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x16x32_F32E4M3E4M3_RS_TN = SM90::GMMA::MMA_64x16x32_F32E4M3E4M3_RS_TN; + template struct MMA_Traits> { @@ -9828,6 +12191,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x32x32_F16E4M3E4M3_SS_TN = SM90::GMMA::MMA_64x32x32_F16E4M3E4M3_SS_TN; + template struct MMA_Traits> { @@ -9850,6 +12220,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x32x32_F16E4M3E4M3_RS_TN = SM90::GMMA::MMA_64x32x32_F16E4M3E4M3_RS_TN; + template struct MMA_Traits> { @@ -9871,6 +12248,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x32x32_F32E4M3E4M3_SS_TN = SM90::GMMA::MMA_64x32x32_F32E4M3E4M3_SS_TN; + template struct MMA_Traits> { @@ -9893,6 +12277,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x32x32_F32E4M3E4M3_RS_TN = SM90::GMMA::MMA_64x32x32_F32E4M3E4M3_RS_TN; + template struct MMA_Traits> { @@ -9915,6 +12306,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x48x32_F16E4M3E4M3_SS_TN = SM90::GMMA::MMA_64x48x32_F16E4M3E4M3_SS_TN; + template struct MMA_Traits> { @@ -9939,6 +12337,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x48x32_F16E4M3E4M3_RS_TN = SM90::GMMA::MMA_64x48x32_F16E4M3E4M3_RS_TN; + template struct MMA_Traits> { @@ -9962,6 +12367,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x48x32_F32E4M3E4M3_SS_TN = SM90::GMMA::MMA_64x48x32_F32E4M3E4M3_SS_TN; + template struct MMA_Traits> { @@ -9986,6 +12398,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x48x32_F32E4M3E4M3_RS_TN = SM90::GMMA::MMA_64x48x32_F32E4M3E4M3_RS_TN; + template struct MMA_Traits> { @@ -10008,6 +12427,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + 
GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x64x32_F16E4M3E4M3_SS_TN = SM90::GMMA::MMA_64x64x32_F16E4M3E4M3_SS_TN; + template struct MMA_Traits> { @@ -10028,7 +12454,14 @@ struct MMA_Traits> GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; -//////////////////////////////////////////////////////////////////////////////////////////////////// +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x64x32_F16E4M3E4M3_RS_TN = SM90::GMMA::MMA_64x64x32_F16E4M3E4M3_RS_TN; template struct MMA_Traits> @@ -10051,6 +12484,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x64x32_F32E4M3E4M3_SS_TN = SM90::GMMA::MMA_64x64x32_F32E4M3E4M3_SS_TN; + template struct MMA_Traits> { @@ -10073,6 +12513,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x64x32_F32E4M3E4M3_RS_TN = SM90::GMMA::MMA_64x64x32_F32E4M3E4M3_RS_TN; + template struct MMA_Traits> { @@ -10095,6 +12542,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x80x32_F16E4M3E4M3_SS_TN = SM90::GMMA::MMA_64x80x32_F16E4M3E4M3_SS_TN; + template struct MMA_Traits> { @@ -10119,6 +12573,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x80x32_F16E4M3E4M3_RS_TN = SM90::GMMA::MMA_64x80x32_F16E4M3E4M3_RS_TN; + template struct MMA_Traits> { @@ -10142,6 +12603,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x80x32_F32E4M3E4M3_SS_TN = SM90::GMMA::MMA_64x80x32_F32E4M3E4M3_SS_TN; + template struct MMA_Traits> { @@ -10166,6 +12634,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x80x32_F32E4M3E4M3_RS_TN = SM90::GMMA::MMA_64x80x32_F32E4M3E4M3_RS_TN; + template struct MMA_Traits> { @@ -10188,6 +12663,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x96x32_F16E4M3E4M3_SS_TN = SM90::GMMA::MMA_64x96x32_F16E4M3E4M3_SS_TN; + template struct MMA_Traits> { @@ -10210,6 +12692,13 @@ struct MMA_Traits> 
//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x96x32_F16E4M3E4M3_RS_TN = SM90::GMMA::MMA_64x96x32_F16E4M3E4M3_RS_TN; + template struct MMA_Traits> { @@ -10231,6 +12720,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x96x32_F32E4M3E4M3_SS_TN = SM90::GMMA::MMA_64x96x32_F32E4M3E4M3_SS_TN; + template struct MMA_Traits> { @@ -10253,6 +12749,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x96x32_F32E4M3E4M3_RS_TN = SM90::GMMA::MMA_64x96x32_F32E4M3E4M3_RS_TN; + template struct MMA_Traits> { @@ -10275,6 +12778,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x112x32_F16E4M3E4M3_SS_TN = SM90::GMMA::MMA_64x112x32_F16E4M3E4M3_SS_TN; + template struct MMA_Traits> { @@ -10299,6 +12809,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x112x32_F16E4M3E4M3_RS_TN = SM90::GMMA::MMA_64x112x32_F16E4M3E4M3_RS_TN; + template struct MMA_Traits> { @@ -10322,6 +12839,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x112x32_F32E4M3E4M3_SS_TN = SM90::GMMA::MMA_64x112x32_F32E4M3E4M3_SS_TN; + template struct MMA_Traits> { @@ -10346,6 +12870,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x112x32_F32E4M3E4M3_RS_TN = SM90::GMMA::MMA_64x112x32_F32E4M3E4M3_RS_TN; + template struct MMA_Traits> { @@ -10368,6 +12899,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x128x32_F16E4M3E4M3_SS_TN = SM90::GMMA::MMA_64x128x32_F16E4M3E4M3_SS_TN; + template struct MMA_Traits> { @@ -10390,6 +12928,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x128x32_F16E4M3E4M3_RS_TN = SM90::GMMA::MMA_64x128x32_F16E4M3E4M3_RS_TN; + template struct MMA_Traits> { @@ -10411,6 +12956,13 @@ struct MMA_Traits> 
//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x128x32_F32E4M3E4M3_SS_TN = SM90::GMMA::MMA_64x128x32_F32E4M3E4M3_SS_TN; + template struct MMA_Traits> { @@ -10433,6 +12985,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x128x32_F32E4M3E4M3_RS_TN = SM90::GMMA::MMA_64x128x32_F32E4M3E4M3_RS_TN; + template struct MMA_Traits> { @@ -10455,6 +13014,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x144x32_F16E4M3E4M3_SS_TN = SM90::GMMA::MMA_64x144x32_F16E4M3E4M3_SS_TN; + template struct MMA_Traits> { @@ -10479,6 +13045,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x144x32_F16E4M3E4M3_RS_TN = SM90::GMMA::MMA_64x144x32_F16E4M3E4M3_RS_TN; + template struct MMA_Traits> { @@ -10502,6 +13075,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x144x32_F32E4M3E4M3_SS_TN = SM90::GMMA::MMA_64x144x32_F32E4M3E4M3_SS_TN; + template struct MMA_Traits> { @@ -10526,6 +13106,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x144x32_F32E4M3E4M3_RS_TN = SM90::GMMA::MMA_64x144x32_F32E4M3E4M3_RS_TN; + template struct MMA_Traits> { @@ -10549,6 +13136,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x160x32_F16E4M3E4M3_SS_TN = SM90::GMMA::MMA_64x160x32_F16E4M3E4M3_SS_TN; + template struct MMA_Traits> { @@ -10573,6 +13167,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x160x32_F16E4M3E4M3_RS_TN = SM90::GMMA::MMA_64x160x32_F16E4M3E4M3_RS_TN; + template struct MMA_Traits> { @@ -10596,6 +13197,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using 
SM90_64x160x32_F32E4M3E4M3_SS_TN = SM90::GMMA::MMA_64x160x32_F32E4M3E4M3_SS_TN; + template struct MMA_Traits> { @@ -10620,6 +13228,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x160x32_F32E4M3E4M3_RS_TN = SM90::GMMA::MMA_64x160x32_F32E4M3E4M3_RS_TN; + template struct MMA_Traits> { @@ -10643,6 +13258,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x176x32_F16E4M3E4M3_SS_TN = SM90::GMMA::MMA_64x176x32_F16E4M3E4M3_SS_TN; + template struct MMA_Traits> { @@ -10667,6 +13289,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x176x32_F16E4M3E4M3_RS_TN = SM90::GMMA::MMA_64x176x32_F16E4M3E4M3_RS_TN; + template struct MMA_Traits> { @@ -10690,6 +13319,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x176x32_F32E4M3E4M3_SS_TN = SM90::GMMA::MMA_64x176x32_F32E4M3E4M3_SS_TN; + template struct MMA_Traits> { @@ -10714,6 +13350,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x176x32_F32E4M3E4M3_RS_TN = SM90::GMMA::MMA_64x176x32_F32E4M3E4M3_RS_TN; + template struct MMA_Traits> { @@ -10736,6 +13379,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x192x32_F16E4M3E4M3_SS_TN = SM90::GMMA::MMA_64x192x32_F16E4M3E4M3_SS_TN; + template struct MMA_Traits> { @@ -10758,6 +13408,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x192x32_F16E4M3E4M3_RS_TN = SM90::GMMA::MMA_64x192x32_F16E4M3E4M3_RS_TN; + template struct MMA_Traits> { @@ -10779,6 +13436,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x192x32_F32E4M3E4M3_SS_TN = SM90::GMMA::MMA_64x192x32_F32E4M3E4M3_SS_TN; + template struct MMA_Traits> { @@ -10801,6 +13465,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB 
= GMMA::ScaleIn::One +> +using SM90_64x192x32_F32E4M3E4M3_RS_TN = SM90::GMMA::MMA_64x192x32_F32E4M3E4M3_RS_TN; + template struct MMA_Traits> { @@ -10823,6 +13494,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x208x32_F16E4M3E4M3_SS_TN = SM90::GMMA::MMA_64x208x32_F16E4M3E4M3_SS_TN; + template struct MMA_Traits> { @@ -10847,6 +13525,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x208x32_F16E4M3E4M3_RS_TN = SM90::GMMA::MMA_64x208x32_F16E4M3E4M3_RS_TN; + template struct MMA_Traits> { @@ -10870,6 +13555,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x208x32_F32E4M3E4M3_SS_TN = SM90::GMMA::MMA_64x208x32_F32E4M3E4M3_SS_TN; + template struct MMA_Traits> { @@ -10894,6 +13586,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x208x32_F32E4M3E4M3_RS_TN = SM90::GMMA::MMA_64x208x32_F32E4M3E4M3_RS_TN; + template struct MMA_Traits> { @@ -10917,6 +13616,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x224x32_F16E4M3E4M3_SS_TN = SM90::GMMA::MMA_64x224x32_F16E4M3E4M3_SS_TN; + template struct MMA_Traits> { @@ -10941,6 +13647,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x224x32_F16E4M3E4M3_RS_TN = SM90::GMMA::MMA_64x224x32_F16E4M3E4M3_RS_TN; + template struct MMA_Traits> { @@ -10964,6 +13677,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x224x32_F32E4M3E4M3_SS_TN = SM90::GMMA::MMA_64x224x32_F32E4M3E4M3_SS_TN; + template struct MMA_Traits> { @@ -10988,6 +13708,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x224x32_F32E4M3E4M3_RS_TN = SM90::GMMA::MMA_64x224x32_F32E4M3E4M3_RS_TN; + template struct MMA_Traits> { @@ -11011,6 +13738,13 @@ struct MMA_Traits> 
//////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x240x32_F16E4M3E4M3_SS_TN = SM90::GMMA::MMA_64x240x32_F16E4M3E4M3_SS_TN; + template struct MMA_Traits> { @@ -11035,6 +13769,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x240x32_F16E4M3E4M3_RS_TN = SM90::GMMA::MMA_64x240x32_F16E4M3E4M3_RS_TN; + template struct MMA_Traits> { @@ -11058,6 +13799,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x240x32_F32E4M3E4M3_SS_TN = SM90::GMMA::MMA_64x240x32_F32E4M3E4M3_SS_TN; + template struct MMA_Traits> { @@ -11082,6 +13830,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x240x32_F32E4M3E4M3_RS_TN = SM90::GMMA::MMA_64x240x32_F32E4M3E4M3_RS_TN; + template struct MMA_Traits> { @@ -11104,6 +13859,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x256x32_F16E4M3E4M3_SS_TN = SM90::GMMA::MMA_64x256x32_F16E4M3E4M3_SS_TN; + template struct MMA_Traits> { @@ -11126,6 +13888,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x256x32_F16E4M3E4M3_RS_TN = SM90::GMMA::MMA_64x256x32_F16E4M3E4M3_RS_TN; + template struct MMA_Traits> { @@ -11147,6 +13916,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x256x32_F32E4M3E4M3_SS_TN = SM90::GMMA::MMA_64x256x32_F32E4M3E4M3_SS_TN; + template struct MMA_Traits> { @@ -11169,6 +13945,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x256x32_F32E4M3E4M3_RS_TN = SM90::GMMA::MMA_64x256x32_F32E4M3E4M3_RS_TN; + template struct MMA_Traits> { @@ -11190,6 +13973,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x8x32_F16E4M3E5M2_SS_TN = SM90::GMMA::MMA_64x8x32_F16E4M3E5M2_SS_TN; + template struct MMA_Traits> { @@ -11212,6 +14002,13 @@ struct MMA_Traits> 
//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x8x32_F16E4M3E5M2_RS_TN = SM90::GMMA::MMA_64x8x32_F16E4M3E5M2_RS_TN; + template struct MMA_Traits> { @@ -11233,6 +14030,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x8x32_F32E4M3E5M2_SS_TN = SM90::GMMA::MMA_64x8x32_F32E4M3E5M2_SS_TN; + template struct MMA_Traits> { @@ -11255,6 +14059,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x8x32_F32E4M3E5M2_RS_TN = SM90::GMMA::MMA_64x8x32_F32E4M3E5M2_RS_TN; + template struct MMA_Traits> { @@ -11276,6 +14087,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x16x32_F16E4M3E5M2_SS_TN = SM90::GMMA::MMA_64x16x32_F16E4M3E5M2_SS_TN; + template struct MMA_Traits> { @@ -11298,6 +14116,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x16x32_F16E4M3E5M2_RS_TN = SM90::GMMA::MMA_64x16x32_F16E4M3E5M2_RS_TN; + template struct MMA_Traits> { @@ -11319,6 +14144,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x16x32_F32E4M3E5M2_SS_TN = SM90::GMMA::MMA_64x16x32_F32E4M3E5M2_SS_TN; + template struct MMA_Traits> { @@ -11341,6 +14173,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x16x32_F32E4M3E5M2_RS_TN = SM90::GMMA::MMA_64x16x32_F32E4M3E5M2_RS_TN; + template struct MMA_Traits> { @@ -11362,6 +14201,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x32x32_F16E4M3E5M2_SS_TN = SM90::GMMA::MMA_64x32x32_F16E4M3E5M2_SS_TN; + template struct MMA_Traits> { @@ -11384,6 +14230,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x32x32_F16E4M3E5M2_RS_TN = SM90::GMMA::MMA_64x32x32_F16E4M3E5M2_RS_TN; + template struct MMA_Traits> { @@ -11405,6 +14258,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x32x32_F32E4M3E5M2_SS_TN = 
SM90::GMMA::MMA_64x32x32_F32E4M3E5M2_SS_TN; + template struct MMA_Traits> { @@ -11427,6 +14287,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x32x32_F32E4M3E5M2_RS_TN = SM90::GMMA::MMA_64x32x32_F32E4M3E5M2_RS_TN; + template struct MMA_Traits> { @@ -11449,6 +14316,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x48x32_F16E4M3E5M2_SS_TN = SM90::GMMA::MMA_64x48x32_F16E4M3E5M2_SS_TN; + template struct MMA_Traits> { @@ -11473,6 +14347,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x48x32_F16E4M3E5M2_RS_TN = SM90::GMMA::MMA_64x48x32_F16E4M3E5M2_RS_TN; + template struct MMA_Traits> { @@ -11496,6 +14377,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x48x32_F32E4M3E5M2_SS_TN = SM90::GMMA::MMA_64x48x32_F32E4M3E5M2_SS_TN; + template struct MMA_Traits> { @@ -11520,6 +14408,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x48x32_F32E4M3E5M2_RS_TN = SM90::GMMA::MMA_64x48x32_F32E4M3E5M2_RS_TN; + template struct MMA_Traits> { @@ -11542,6 +14437,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x64x32_F16E4M3E5M2_SS_TN = SM90::GMMA::MMA_64x64x32_F16E4M3E5M2_SS_TN; + template struct MMA_Traits> { @@ -11564,6 +14466,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x64x32_F16E4M3E5M2_RS_TN = SM90::GMMA::MMA_64x64x32_F16E4M3E5M2_RS_TN; + template struct MMA_Traits> { @@ -11585,6 +14494,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x64x32_F32E4M3E5M2_SS_TN = SM90::GMMA::MMA_64x64x32_F32E4M3E5M2_SS_TN; + template struct MMA_Traits> { @@ -11607,6 +14523,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x64x32_F32E4M3E5M2_RS_TN = 
SM90::GMMA::MMA_64x64x32_F32E4M3E5M2_RS_TN; + template struct MMA_Traits> { @@ -11629,6 +14552,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x80x32_F16E4M3E5M2_SS_TN = SM90::GMMA::MMA_64x80x32_F16E4M3E5M2_SS_TN; + template struct MMA_Traits> { @@ -11653,6 +14583,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x80x32_F16E4M3E5M2_RS_TN = SM90::GMMA::MMA_64x80x32_F16E4M3E5M2_RS_TN; + template struct MMA_Traits> { @@ -11676,6 +14613,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x80x32_F32E4M3E5M2_SS_TN = SM90::GMMA::MMA_64x80x32_F32E4M3E5M2_SS_TN; + template struct MMA_Traits> { @@ -11700,6 +14644,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x80x32_F32E4M3E5M2_RS_TN = SM90::GMMA::MMA_64x80x32_F32E4M3E5M2_RS_TN; + template struct MMA_Traits> { @@ -11722,6 +14673,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x96x32_F16E4M3E5M2_SS_TN = SM90::GMMA::MMA_64x96x32_F16E4M3E5M2_SS_TN; + template struct MMA_Traits> { @@ -11744,6 +14702,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x96x32_F16E4M3E5M2_RS_TN = SM90::GMMA::MMA_64x96x32_F16E4M3E5M2_RS_TN; + template struct MMA_Traits> { @@ -11765,6 +14730,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x96x32_F32E4M3E5M2_SS_TN = SM90::GMMA::MMA_64x96x32_F32E4M3E5M2_SS_TN; + template struct MMA_Traits> { @@ -11787,6 +14759,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x96x32_F32E4M3E5M2_RS_TN = SM90::GMMA::MMA_64x96x32_F32E4M3E5M2_RS_TN; + template struct MMA_Traits> { @@ -11809,6 +14788,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using 
SM90_64x112x32_F16E4M3E5M2_SS_TN = SM90::GMMA::MMA_64x112x32_F16E4M3E5M2_SS_TN; + template struct MMA_Traits> { @@ -11833,6 +14819,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x112x32_F16E4M3E5M2_RS_TN = SM90::GMMA::MMA_64x112x32_F16E4M3E5M2_RS_TN; + template struct MMA_Traits> { @@ -11856,6 +14849,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x112x32_F32E4M3E5M2_SS_TN = SM90::GMMA::MMA_64x112x32_F32E4M3E5M2_SS_TN; + template struct MMA_Traits> { @@ -11880,6 +14880,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x112x32_F32E4M3E5M2_RS_TN = SM90::GMMA::MMA_64x112x32_F32E4M3E5M2_RS_TN; + template struct MMA_Traits> { @@ -11902,6 +14909,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x128x32_F16E4M3E5M2_SS_TN = SM90::GMMA::MMA_64x128x32_F16E4M3E5M2_SS_TN; + template struct MMA_Traits> { @@ -11924,6 +14938,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x128x32_F16E4M3E5M2_RS_TN = SM90::GMMA::MMA_64x128x32_F16E4M3E5M2_RS_TN; + template struct MMA_Traits> { @@ -11945,6 +14966,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x128x32_F32E4M3E5M2_SS_TN = SM90::GMMA::MMA_64x128x32_F32E4M3E5M2_SS_TN; + template struct MMA_Traits> { @@ -11967,6 +14995,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x128x32_F32E4M3E5M2_RS_TN = SM90::GMMA::MMA_64x128x32_F32E4M3E5M2_RS_TN; + template struct MMA_Traits> { @@ -11989,6 +15024,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x144x32_F16E4M3E5M2_SS_TN = SM90::GMMA::MMA_64x144x32_F16E4M3E5M2_SS_TN; + template struct MMA_Traits> { @@ -12013,6 +15055,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB 
= GMMA::ScaleIn::One +> +using SM90_64x144x32_F16E4M3E5M2_RS_TN = SM90::GMMA::MMA_64x144x32_F16E4M3E5M2_RS_TN; + template struct MMA_Traits> { @@ -12036,6 +15085,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x144x32_F32E4M3E5M2_SS_TN = SM90::GMMA::MMA_64x144x32_F32E4M3E5M2_SS_TN; + template struct MMA_Traits> { @@ -12060,6 +15116,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x144x32_F32E4M3E5M2_RS_TN = SM90::GMMA::MMA_64x144x32_F32E4M3E5M2_RS_TN; + template struct MMA_Traits> { @@ -12083,6 +15146,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x160x32_F16E4M3E5M2_SS_TN = SM90::GMMA::MMA_64x160x32_F16E4M3E5M2_SS_TN; + template struct MMA_Traits> { @@ -12107,6 +15177,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x160x32_F16E4M3E5M2_RS_TN = SM90::GMMA::MMA_64x160x32_F16E4M3E5M2_RS_TN; + template struct MMA_Traits> { @@ -12130,6 +15207,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x160x32_F32E4M3E5M2_SS_TN = SM90::GMMA::MMA_64x160x32_F32E4M3E5M2_SS_TN; + template struct MMA_Traits> { @@ -12154,6 +15238,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x160x32_F32E4M3E5M2_RS_TN = SM90::GMMA::MMA_64x160x32_F32E4M3E5M2_RS_TN; + template struct MMA_Traits> { @@ -12177,6 +15268,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x176x32_F16E4M3E5M2_SS_TN = SM90::GMMA::MMA_64x176x32_F16E4M3E5M2_SS_TN; + template struct MMA_Traits> { @@ -12201,6 +15299,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x176x32_F16E4M3E5M2_RS_TN = SM90::GMMA::MMA_64x176x32_F16E4M3E5M2_RS_TN; + template struct MMA_Traits> { @@ -12224,6 +15329,13 @@ struct MMA_Traits> 
//////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x176x32_F32E4M3E5M2_SS_TN = SM90::GMMA::MMA_64x176x32_F32E4M3E5M2_SS_TN; + template struct MMA_Traits> { @@ -12248,6 +15360,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x176x32_F32E4M3E5M2_RS_TN = SM90::GMMA::MMA_64x176x32_F32E4M3E5M2_RS_TN; + template struct MMA_Traits> { @@ -12270,6 +15389,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x192x32_F16E4M3E5M2_SS_TN = SM90::GMMA::MMA_64x192x32_F16E4M3E5M2_SS_TN; + template struct MMA_Traits> { @@ -12292,6 +15418,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x192x32_F16E4M3E5M2_RS_TN = SM90::GMMA::MMA_64x192x32_F16E4M3E5M2_RS_TN; + template struct MMA_Traits> { @@ -12313,6 +15446,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x192x32_F32E4M3E5M2_SS_TN = SM90::GMMA::MMA_64x192x32_F32E4M3E5M2_SS_TN; + template struct MMA_Traits> { @@ -12335,6 +15475,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x192x32_F32E4M3E5M2_RS_TN = SM90::GMMA::MMA_64x192x32_F32E4M3E5M2_RS_TN; + template struct MMA_Traits> { @@ -12357,6 +15504,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x208x32_F16E4M3E5M2_SS_TN = SM90::GMMA::MMA_64x208x32_F16E4M3E5M2_SS_TN; + template struct MMA_Traits> { @@ -12381,6 +15535,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x208x32_F16E4M3E5M2_RS_TN = SM90::GMMA::MMA_64x208x32_F16E4M3E5M2_RS_TN; + template struct MMA_Traits> { @@ -12404,6 +15565,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x208x32_F32E4M3E5M2_SS_TN = SM90::GMMA::MMA_64x208x32_F32E4M3E5M2_SS_TN; + template struct MMA_Traits> { @@ -12428,6 
+15596,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x208x32_F32E4M3E5M2_RS_TN = SM90::GMMA::MMA_64x208x32_F32E4M3E5M2_RS_TN; + template struct MMA_Traits> { @@ -12451,6 +15626,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x224x32_F16E4M3E5M2_SS_TN = SM90::GMMA::MMA_64x224x32_F16E4M3E5M2_SS_TN; + template struct MMA_Traits> { @@ -12475,6 +15657,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x224x32_F16E4M3E5M2_RS_TN = SM90::GMMA::MMA_64x224x32_F16E4M3E5M2_RS_TN; + template struct MMA_Traits> { @@ -12498,6 +15687,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x224x32_F32E4M3E5M2_SS_TN = SM90::GMMA::MMA_64x224x32_F32E4M3E5M2_SS_TN; + template struct MMA_Traits> { @@ -12522,6 +15718,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x224x32_F32E4M3E5M2_RS_TN = SM90::GMMA::MMA_64x224x32_F32E4M3E5M2_RS_TN; + template struct MMA_Traits> { @@ -12545,6 +15748,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x240x32_F16E4M3E5M2_SS_TN = SM90::GMMA::MMA_64x240x32_F16E4M3E5M2_SS_TN; + template struct MMA_Traits> { @@ -12569,6 +15779,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x240x32_F16E4M3E5M2_RS_TN = SM90::GMMA::MMA_64x240x32_F16E4M3E5M2_RS_TN; + template struct MMA_Traits> { @@ -12592,6 +15809,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x240x32_F32E4M3E5M2_SS_TN = SM90::GMMA::MMA_64x240x32_F32E4M3E5M2_SS_TN; + template struct MMA_Traits> { @@ -12616,6 +15840,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + 
+template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x240x32_F32E4M3E5M2_RS_TN = SM90::GMMA::MMA_64x240x32_F32E4M3E5M2_RS_TN; + template struct MMA_Traits> { @@ -12638,6 +15869,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x256x32_F16E4M3E5M2_SS_TN = SM90::GMMA::MMA_64x256x32_F16E4M3E5M2_SS_TN; + template struct MMA_Traits> { @@ -12660,6 +15898,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x256x32_F16E4M3E5M2_RS_TN = SM90::GMMA::MMA_64x256x32_F16E4M3E5M2_RS_TN; + template struct MMA_Traits> { @@ -12681,6 +15926,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x256x32_F32E4M3E5M2_SS_TN = SM90::GMMA::MMA_64x256x32_F32E4M3E5M2_SS_TN; + template struct MMA_Traits> { @@ -12703,6 +15955,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x256x32_F32E4M3E5M2_RS_TN = SM90::GMMA::MMA_64x256x32_F32E4M3E5M2_RS_TN; + template struct MMA_Traits> { @@ -12724,6 +15983,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x8x32_F16E5M2E4M3_SS_TN = SM90::GMMA::MMA_64x8x32_F16E5M2E4M3_SS_TN; + template struct MMA_Traits> { @@ -12746,6 +16012,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x8x32_F16E5M2E4M3_RS_TN = SM90::GMMA::MMA_64x8x32_F16E5M2E4M3_RS_TN; + template struct MMA_Traits> { @@ -12767,6 +16040,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x8x32_F32E5M2E4M3_SS_TN = SM90::GMMA::MMA_64x8x32_F32E5M2E4M3_SS_TN; + template struct MMA_Traits> { @@ -12789,6 +16069,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x8x32_F32E5M2E4M3_RS_TN = SM90::GMMA::MMA_64x8x32_F32E5M2E4M3_RS_TN; + template struct MMA_Traits> { @@ -12810,6 +16097,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x16x32_F16E5M2E4M3_SS_TN = SM90::GMMA::MMA_64x16x32_F16E5M2E4M3_SS_TN; + template struct MMA_Traits> { @@ -12832,6 
+16126,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x16x32_F16E5M2E4M3_RS_TN = SM90::GMMA::MMA_64x16x32_F16E5M2E4M3_RS_TN; + template struct MMA_Traits> { @@ -12853,6 +16154,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x16x32_F32E5M2E4M3_SS_TN = SM90::GMMA::MMA_64x16x32_F32E5M2E4M3_SS_TN; + template struct MMA_Traits> { @@ -12875,6 +16183,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x16x32_F32E5M2E4M3_RS_TN = SM90::GMMA::MMA_64x16x32_F32E5M2E4M3_RS_TN; + template struct MMA_Traits> { @@ -12894,7 +16209,14 @@ struct MMA_Traits> GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; -//////////////////////////////////////////////////////////////////////////////////////////////////// +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x32x32_F16E5M2E4M3_SS_TN = SM90::GMMA::MMA_64x32x32_F16E5M2E4M3_SS_TN; template struct MMA_Traits> @@ -12918,6 +16240,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x32x32_F16E5M2E4M3_RS_TN = SM90::GMMA::MMA_64x32x32_F16E5M2E4M3_RS_TN; + template struct MMA_Traits> { @@ -12939,6 +16268,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x32x32_F32E5M2E4M3_SS_TN = SM90::GMMA::MMA_64x32x32_F32E5M2E4M3_SS_TN; + template struct MMA_Traits> { @@ -12961,6 +16297,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x32x32_F32E5M2E4M3_RS_TN = SM90::GMMA::MMA_64x32x32_F32E5M2E4M3_RS_TN; + template struct MMA_Traits> { @@ -12983,6 +16326,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x48x32_F16E5M2E4M3_SS_TN = SM90::GMMA::MMA_64x48x32_F16E5M2E4M3_SS_TN; + template struct MMA_Traits> { @@ -13007,6 +16357,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x48x32_F16E5M2E4M3_RS_TN = SM90::GMMA::MMA_64x48x32_F16E5M2E4M3_RS_TN; + template struct MMA_Traits> 
{ @@ -13030,6 +16387,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x48x32_F32E5M2E4M3_SS_TN = SM90::GMMA::MMA_64x48x32_F32E5M2E4M3_SS_TN; + template struct MMA_Traits> { @@ -13054,6 +16418,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x48x32_F32E5M2E4M3_RS_TN = SM90::GMMA::MMA_64x48x32_F32E5M2E4M3_RS_TN; + template struct MMA_Traits> { @@ -13076,6 +16447,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x64x32_F16E5M2E4M3_SS_TN = SM90::GMMA::MMA_64x64x32_F16E5M2E4M3_SS_TN; + template struct MMA_Traits> { @@ -13098,6 +16476,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x64x32_F16E5M2E4M3_RS_TN = SM90::GMMA::MMA_64x64x32_F16E5M2E4M3_RS_TN; + template struct MMA_Traits> { @@ -13119,6 +16504,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x64x32_F32E5M2E4M3_SS_TN = SM90::GMMA::MMA_64x64x32_F32E5M2E4M3_SS_TN; + template struct MMA_Traits> { @@ -13141,6 +16533,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x64x32_F32E5M2E4M3_RS_TN = SM90::GMMA::MMA_64x64x32_F32E5M2E4M3_RS_TN; + template struct MMA_Traits> { @@ -13163,6 +16562,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x80x32_F16E5M2E4M3_SS_TN = SM90::GMMA::MMA_64x80x32_F16E5M2E4M3_SS_TN; + template struct MMA_Traits> { @@ -13187,6 +16593,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x80x32_F16E5M2E4M3_RS_TN = SM90::GMMA::MMA_64x80x32_F16E5M2E4M3_RS_TN; + template struct MMA_Traits> { @@ -13210,6 +16623,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x80x32_F32E5M2E4M3_SS_TN = SM90::GMMA::MMA_64x80x32_F32E5M2E4M3_SS_TN; + template struct 
MMA_Traits> { @@ -13234,6 +16654,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x80x32_F32E5M2E4M3_RS_TN = SM90::GMMA::MMA_64x80x32_F32E5M2E4M3_RS_TN; + template struct MMA_Traits> { @@ -13256,6 +16683,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x96x32_F16E5M2E4M3_SS_TN = SM90::GMMA::MMA_64x96x32_F16E5M2E4M3_SS_TN; + template struct MMA_Traits> { @@ -13278,6 +16712,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x96x32_F16E5M2E4M3_RS_TN = SM90::GMMA::MMA_64x96x32_F16E5M2E4M3_RS_TN; + template struct MMA_Traits> { @@ -13299,6 +16740,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x96x32_F32E5M2E4M3_SS_TN = SM90::GMMA::MMA_64x96x32_F32E5M2E4M3_SS_TN; + template struct MMA_Traits> { @@ -13321,6 +16769,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x96x32_F32E5M2E4M3_RS_TN = SM90::GMMA::MMA_64x96x32_F32E5M2E4M3_RS_TN; + template struct MMA_Traits> { @@ -13343,6 +16798,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x112x32_F16E5M2E4M3_SS_TN = SM90::GMMA::MMA_64x112x32_F16E5M2E4M3_SS_TN; + template struct MMA_Traits> { @@ -13367,6 +16829,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x112x32_F16E5M2E4M3_RS_TN = SM90::GMMA::MMA_64x112x32_F16E5M2E4M3_RS_TN; + template struct MMA_Traits> { @@ -13390,6 +16859,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x112x32_F32E5M2E4M3_SS_TN = SM90::GMMA::MMA_64x112x32_F32E5M2E4M3_SS_TN; + template struct MMA_Traits> { @@ -13414,6 +16890,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x112x32_F32E5M2E4M3_RS_TN = 
SM90::GMMA::MMA_64x112x32_F32E5M2E4M3_RS_TN; + template struct MMA_Traits> { @@ -13436,6 +16919,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x128x32_F16E5M2E4M3_SS_TN = SM90::GMMA::MMA_64x128x32_F16E5M2E4M3_SS_TN; + template struct MMA_Traits> { @@ -13458,6 +16948,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x128x32_F16E5M2E4M3_RS_TN = SM90::GMMA::MMA_64x128x32_F16E5M2E4M3_RS_TN; + template struct MMA_Traits> { @@ -13479,6 +16976,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x128x32_F32E5M2E4M3_SS_TN = SM90::GMMA::MMA_64x128x32_F32E5M2E4M3_SS_TN; + template struct MMA_Traits> { @@ -13501,6 +17005,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x128x32_F32E5M2E4M3_RS_TN = SM90::GMMA::MMA_64x128x32_F32E5M2E4M3_RS_TN; + template struct MMA_Traits> { @@ -13523,6 +17034,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x144x32_F16E5M2E4M3_SS_TN = SM90::GMMA::MMA_64x144x32_F16E5M2E4M3_SS_TN; + template struct MMA_Traits> { @@ -13547,6 +17065,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x144x32_F16E5M2E4M3_RS_TN = SM90::GMMA::MMA_64x144x32_F16E5M2E4M3_RS_TN; + template struct MMA_Traits> { @@ -13570,6 +17095,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x144x32_F32E5M2E4M3_SS_TN = SM90::GMMA::MMA_64x144x32_F32E5M2E4M3_SS_TN; + template struct MMA_Traits> { @@ -13594,6 +17126,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x144x32_F32E5M2E4M3_RS_TN = SM90::GMMA::MMA_64x144x32_F32E5M2E4M3_RS_TN; + template struct MMA_Traits> { @@ -13617,6 +17156,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using 
SM90_64x160x32_F16E5M2E4M3_SS_TN = SM90::GMMA::MMA_64x160x32_F16E5M2E4M3_SS_TN; + template struct MMA_Traits> { @@ -13641,6 +17187,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x160x32_F16E5M2E4M3_RS_TN = SM90::GMMA::MMA_64x160x32_F16E5M2E4M3_RS_TN; + template struct MMA_Traits> { @@ -13664,6 +17217,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x160x32_F32E5M2E4M3_SS_TN = SM90::GMMA::MMA_64x160x32_F32E5M2E4M3_SS_TN; + template struct MMA_Traits> { @@ -13688,6 +17248,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x160x32_F32E5M2E4M3_RS_TN = SM90::GMMA::MMA_64x160x32_F32E5M2E4M3_RS_TN; + template struct MMA_Traits> { @@ -13711,6 +17278,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x176x32_F16E5M2E4M3_SS_TN = SM90::GMMA::MMA_64x176x32_F16E5M2E4M3_SS_TN; + template struct MMA_Traits> { @@ -13735,6 +17309,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x176x32_F16E5M2E4M3_RS_TN = SM90::GMMA::MMA_64x176x32_F16E5M2E4M3_RS_TN; + template struct MMA_Traits> { @@ -13758,6 +17339,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x176x32_F32E5M2E4M3_SS_TN = SM90::GMMA::MMA_64x176x32_F32E5M2E4M3_SS_TN; + template struct MMA_Traits> { @@ -13782,6 +17370,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x176x32_F32E5M2E4M3_RS_TN = SM90::GMMA::MMA_64x176x32_F32E5M2E4M3_RS_TN; + template struct MMA_Traits> { @@ -13804,6 +17399,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x192x32_F16E5M2E4M3_SS_TN = SM90::GMMA::MMA_64x192x32_F16E5M2E4M3_SS_TN; + template struct MMA_Traits> { @@ -13826,6 +17428,13 @@ struct MMA_Traits> 
//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x192x32_F16E5M2E4M3_RS_TN = SM90::GMMA::MMA_64x192x32_F16E5M2E4M3_RS_TN; + template struct MMA_Traits> { @@ -13847,6 +17456,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x192x32_F32E5M2E4M3_SS_TN = SM90::GMMA::MMA_64x192x32_F32E5M2E4M3_SS_TN; + template struct MMA_Traits> { @@ -13869,6 +17485,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x192x32_F32E5M2E4M3_RS_TN = SM90::GMMA::MMA_64x192x32_F32E5M2E4M3_RS_TN; + template struct MMA_Traits> { @@ -13891,6 +17514,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x208x32_F16E5M2E4M3_SS_TN = SM90::GMMA::MMA_64x208x32_F16E5M2E4M3_SS_TN; + template struct MMA_Traits> { @@ -13915,6 +17545,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x208x32_F16E5M2E4M3_RS_TN = SM90::GMMA::MMA_64x208x32_F16E5M2E4M3_RS_TN; + template struct MMA_Traits> { @@ -13938,6 +17575,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x208x32_F32E5M2E4M3_SS_TN = SM90::GMMA::MMA_64x208x32_F32E5M2E4M3_SS_TN; + template struct MMA_Traits> { @@ -13962,6 +17606,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x208x32_F32E5M2E4M3_RS_TN = SM90::GMMA::MMA_64x208x32_F32E5M2E4M3_RS_TN; + template struct MMA_Traits> { @@ -13985,6 +17636,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x224x32_F16E5M2E4M3_SS_TN = SM90::GMMA::MMA_64x224x32_F16E5M2E4M3_SS_TN; + template struct MMA_Traits> { @@ -14009,6 +17667,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x224x32_F16E5M2E4M3_RS_TN = 
SM90::GMMA::MMA_64x224x32_F16E5M2E4M3_RS_TN; + template struct MMA_Traits> { @@ -14032,6 +17697,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x224x32_F32E5M2E4M3_SS_TN = SM90::GMMA::MMA_64x224x32_F32E5M2E4M3_SS_TN; + template struct MMA_Traits> { @@ -14056,6 +17728,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x224x32_F32E5M2E4M3_RS_TN = SM90::GMMA::MMA_64x224x32_F32E5M2E4M3_RS_TN; + template struct MMA_Traits> { @@ -14079,6 +17758,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x240x32_F16E5M2E4M3_SS_TN = SM90::GMMA::MMA_64x240x32_F16E5M2E4M3_SS_TN; + template struct MMA_Traits> { @@ -14103,6 +17789,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x240x32_F16E5M2E4M3_RS_TN = SM90::GMMA::MMA_64x240x32_F16E5M2E4M3_RS_TN; + template struct MMA_Traits> { @@ -14126,6 +17819,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x240x32_F32E5M2E4M3_SS_TN = SM90::GMMA::MMA_64x240x32_F32E5M2E4M3_SS_TN; + template struct MMA_Traits> { @@ -14150,6 +17850,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x240x32_F32E5M2E4M3_RS_TN = SM90::GMMA::MMA_64x240x32_F32E5M2E4M3_RS_TN; + template struct MMA_Traits> { @@ -14172,6 +17879,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x256x32_F16E5M2E4M3_SS_TN = SM90::GMMA::MMA_64x256x32_F16E5M2E4M3_SS_TN; + template struct MMA_Traits> { @@ -14194,6 +17908,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x256x32_F16E5M2E4M3_RS_TN = SM90::GMMA::MMA_64x256x32_F16E5M2E4M3_RS_TN; + template struct MMA_Traits> { @@ -14215,6 +17936,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + 
GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x256x32_F32E5M2E4M3_SS_TN = SM90::GMMA::MMA_64x256x32_F32E5M2E4M3_SS_TN; + template struct MMA_Traits> { @@ -14237,6 +17965,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x256x32_F32E5M2E4M3_RS_TN = SM90::GMMA::MMA_64x256x32_F32E5M2E4M3_RS_TN; + template struct MMA_Traits> { @@ -14258,6 +17993,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x8x32_F16E5M2E5M2_SS_TN = SM90::GMMA::MMA_64x8x32_F16E5M2E5M2_SS_TN; + template struct MMA_Traits> { @@ -14280,6 +18022,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x8x32_F16E5M2E5M2_RS_TN = SM90::GMMA::MMA_64x8x32_F16E5M2E5M2_RS_TN; + template struct MMA_Traits> { @@ -14301,6 +18050,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x8x32_F32E5M2E5M2_SS_TN = SM90::GMMA::MMA_64x8x32_F32E5M2E5M2_SS_TN; + template struct MMA_Traits> { @@ -14323,6 +18079,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x8x32_F32E5M2E5M2_RS_TN = SM90::GMMA::MMA_64x8x32_F32E5M2E5M2_RS_TN; + template struct MMA_Traits> { @@ -14344,6 +18107,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x16x32_F16E5M2E5M2_SS_TN = SM90::GMMA::MMA_64x16x32_F16E5M2E5M2_SS_TN; + template struct MMA_Traits> { @@ -14366,6 +18136,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x16x32_F16E5M2E5M2_RS_TN = SM90::GMMA::MMA_64x16x32_F16E5M2E5M2_RS_TN; + template struct MMA_Traits> { @@ -14387,6 +18164,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x16x32_F32E5M2E5M2_SS_TN = SM90::GMMA::MMA_64x16x32_F32E5M2E5M2_SS_TN; + template struct MMA_Traits> { @@ -14409,6 +18193,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x16x32_F32E5M2E5M2_RS_TN = SM90::GMMA::MMA_64x16x32_F32E5M2E5M2_RS_TN; + template struct MMA_Traits> { @@ -14430,6 +18221,13 @@ struct MMA_Traits> 
//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x32x32_F16E5M2E5M2_SS_TN = SM90::GMMA::MMA_64x32x32_F16E5M2E5M2_SS_TN; + template struct MMA_Traits> { @@ -14452,6 +18250,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x32x32_F16E5M2E5M2_RS_TN = SM90::GMMA::MMA_64x32x32_F16E5M2E5M2_RS_TN; + template struct MMA_Traits> { @@ -14473,6 +18278,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x32x32_F32E5M2E5M2_SS_TN = SM90::GMMA::MMA_64x32x32_F32E5M2E5M2_SS_TN; + template struct MMA_Traits> { @@ -14495,6 +18307,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x32x32_F32E5M2E5M2_RS_TN = SM90::GMMA::MMA_64x32x32_F32E5M2E5M2_RS_TN; + template struct MMA_Traits> { @@ -14517,6 +18336,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x48x32_F16E5M2E5M2_SS_TN = SM90::GMMA::MMA_64x48x32_F16E5M2E5M2_SS_TN; + template struct MMA_Traits> { @@ -14541,6 +18367,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x48x32_F16E5M2E5M2_RS_TN = SM90::GMMA::MMA_64x48x32_F16E5M2E5M2_RS_TN; + template struct MMA_Traits> { @@ -14564,6 +18397,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x48x32_F32E5M2E5M2_SS_TN = SM90::GMMA::MMA_64x48x32_F32E5M2E5M2_SS_TN; + template struct MMA_Traits> { @@ -14588,6 +18428,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x48x32_F32E5M2E5M2_RS_TN = SM90::GMMA::MMA_64x48x32_F32E5M2E5M2_RS_TN; + template struct MMA_Traits> { @@ -14610,6 +18457,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x64x32_F16E5M2E5M2_SS_TN = SM90::GMMA::MMA_64x64x32_F16E5M2E5M2_SS_TN; + template struct MMA_Traits> { @@ -14632,6 +18486,13 @@ struct MMA_Traits> 
//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x64x32_F16E5M2E5M2_RS_TN = SM90::GMMA::MMA_64x64x32_F16E5M2E5M2_RS_TN; + template struct MMA_Traits> { @@ -14653,6 +18514,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x64x32_F32E5M2E5M2_SS_TN = SM90::GMMA::MMA_64x64x32_F32E5M2E5M2_SS_TN; + template struct MMA_Traits> { @@ -14675,6 +18543,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x64x32_F32E5M2E5M2_RS_TN = SM90::GMMA::MMA_64x64x32_F32E5M2E5M2_RS_TN; + template struct MMA_Traits> { @@ -14697,6 +18572,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x80x32_F16E5M2E5M2_SS_TN = SM90::GMMA::MMA_64x80x32_F16E5M2E5M2_SS_TN; + template struct MMA_Traits> { @@ -14721,6 +18603,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x80x32_F16E5M2E5M2_RS_TN = SM90::GMMA::MMA_64x80x32_F16E5M2E5M2_RS_TN; + template struct MMA_Traits> { @@ -14744,6 +18633,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x80x32_F32E5M2E5M2_SS_TN = SM90::GMMA::MMA_64x80x32_F32E5M2E5M2_SS_TN; + template struct MMA_Traits> { @@ -14768,6 +18664,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x80x32_F32E5M2E5M2_RS_TN = SM90::GMMA::MMA_64x80x32_F32E5M2E5M2_RS_TN; + template struct MMA_Traits> { @@ -14790,6 +18693,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x96x32_F16E5M2E5M2_SS_TN = SM90::GMMA::MMA_64x96x32_F16E5M2E5M2_SS_TN; + template struct MMA_Traits> { @@ -14812,6 +18722,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x96x32_F16E5M2E5M2_RS_TN = SM90::GMMA::MMA_64x96x32_F16E5M2E5M2_RS_TN; + template struct MMA_Traits> { @@ -14833,6 +18750,13 @@ struct MMA_Traits> 
//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x96x32_F32E5M2E5M2_SS_TN = SM90::GMMA::MMA_64x96x32_F32E5M2E5M2_SS_TN; + template struct MMA_Traits> { @@ -14855,6 +18779,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x96x32_F32E5M2E5M2_RS_TN = SM90::GMMA::MMA_64x96x32_F32E5M2E5M2_RS_TN; + template struct MMA_Traits> { @@ -14877,6 +18808,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x112x32_F16E5M2E5M2_SS_TN = SM90::GMMA::MMA_64x112x32_F16E5M2E5M2_SS_TN; + template struct MMA_Traits> { @@ -14901,6 +18839,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x112x32_F16E5M2E5M2_RS_TN = SM90::GMMA::MMA_64x112x32_F16E5M2E5M2_RS_TN; + template struct MMA_Traits> { @@ -14924,6 +18869,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x112x32_F32E5M2E5M2_SS_TN = SM90::GMMA::MMA_64x112x32_F32E5M2E5M2_SS_TN; + template struct MMA_Traits> { @@ -14948,6 +18900,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x112x32_F32E5M2E5M2_RS_TN = SM90::GMMA::MMA_64x112x32_F32E5M2E5M2_RS_TN; + template struct MMA_Traits> { @@ -14970,6 +18929,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x128x32_F16E5M2E5M2_SS_TN = SM90::GMMA::MMA_64x128x32_F16E5M2E5M2_SS_TN; + template struct MMA_Traits> { @@ -14992,6 +18958,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x128x32_F16E5M2E5M2_RS_TN = SM90::GMMA::MMA_64x128x32_F16E5M2E5M2_RS_TN; + template struct MMA_Traits> { @@ -15013,6 +18986,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x128x32_F32E5M2E5M2_SS_TN = SM90::GMMA::MMA_64x128x32_F32E5M2E5M2_SS_TN; + template struct MMA_Traits> { @@ -15035,6 +19015,13 @@ struct MMA_Traits> 
//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x128x32_F32E5M2E5M2_RS_TN = SM90::GMMA::MMA_64x128x32_F32E5M2E5M2_RS_TN; + template struct MMA_Traits> { @@ -15057,6 +19044,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x144x32_F16E5M2E5M2_SS_TN = SM90::GMMA::MMA_64x144x32_F16E5M2E5M2_SS_TN; + template struct MMA_Traits> { @@ -15081,6 +19075,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x144x32_F16E5M2E5M2_RS_TN = SM90::GMMA::MMA_64x144x32_F16E5M2E5M2_RS_TN; + template struct MMA_Traits> { @@ -15104,6 +19105,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x144x32_F32E5M2E5M2_SS_TN = SM90::GMMA::MMA_64x144x32_F32E5M2E5M2_SS_TN; + template struct MMA_Traits> { @@ -15128,6 +19136,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x144x32_F32E5M2E5M2_RS_TN = SM90::GMMA::MMA_64x144x32_F32E5M2E5M2_RS_TN; + template struct MMA_Traits> { @@ -15151,6 +19166,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x160x32_F16E5M2E5M2_SS_TN = SM90::GMMA::MMA_64x160x32_F16E5M2E5M2_SS_TN; + template struct MMA_Traits> { @@ -15175,6 +19197,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x160x32_F16E5M2E5M2_RS_TN = SM90::GMMA::MMA_64x160x32_F16E5M2E5M2_RS_TN; + template struct MMA_Traits> { @@ -15198,6 +19227,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x160x32_F32E5M2E5M2_SS_TN = SM90::GMMA::MMA_64x160x32_F32E5M2E5M2_SS_TN; + template struct MMA_Traits> { @@ -15222,6 +19258,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = 
GMMA::ScaleIn::One +> +using SM90_64x160x32_F32E5M2E5M2_RS_TN = SM90::GMMA::MMA_64x160x32_F32E5M2E5M2_RS_TN; + template struct MMA_Traits> { @@ -15245,6 +19288,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x176x32_F16E5M2E5M2_SS_TN = SM90::GMMA::MMA_64x176x32_F16E5M2E5M2_SS_TN; + template struct MMA_Traits> { @@ -15269,6 +19319,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x176x32_F16E5M2E5M2_RS_TN = SM90::GMMA::MMA_64x176x32_F16E5M2E5M2_RS_TN; + template struct MMA_Traits> { @@ -15292,6 +19349,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x176x32_F32E5M2E5M2_SS_TN = SM90::GMMA::MMA_64x176x32_F32E5M2E5M2_SS_TN; + template struct MMA_Traits> { @@ -15316,6 +19380,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x176x32_F32E5M2E5M2_RS_TN = SM90::GMMA::MMA_64x176x32_F32E5M2E5M2_RS_TN; + template struct MMA_Traits> { @@ -15338,6 +19409,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x192x32_F16E5M2E5M2_SS_TN = SM90::GMMA::MMA_64x192x32_F16E5M2E5M2_SS_TN; + template struct MMA_Traits> { @@ -15360,6 +19438,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x192x32_F16E5M2E5M2_RS_TN = SM90::GMMA::MMA_64x192x32_F16E5M2E5M2_RS_TN; + template struct MMA_Traits> { @@ -15381,6 +19466,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x192x32_F32E5M2E5M2_SS_TN = SM90::GMMA::MMA_64x192x32_F32E5M2E5M2_SS_TN; + template struct MMA_Traits> { @@ -15403,6 +19495,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x192x32_F32E5M2E5M2_RS_TN = SM90::GMMA::MMA_64x192x32_F32E5M2E5M2_RS_TN; + template struct MMA_Traits> { @@ -15425,6 +19524,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + +template < + GMMA::ScaleIn scaleA = 
GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x208x32_F16E5M2E5M2_SS_TN = SM90::GMMA::MMA_64x208x32_F16E5M2E5M2_SS_TN; + template struct MMA_Traits> { @@ -15449,6 +19555,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x208x32_F16E5M2E5M2_RS_TN = SM90::GMMA::MMA_64x208x32_F16E5M2E5M2_RS_TN; + template struct MMA_Traits> { @@ -15472,6 +19585,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x208x32_F32E5M2E5M2_SS_TN = SM90::GMMA::MMA_64x208x32_F32E5M2E5M2_SS_TN; + template struct MMA_Traits> { @@ -15496,6 +19616,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x208x32_F32E5M2E5M2_RS_TN = SM90::GMMA::MMA_64x208x32_F32E5M2E5M2_RS_TN; + template struct MMA_Traits> { @@ -15519,6 +19646,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x224x32_F16E5M2E5M2_SS_TN = SM90::GMMA::MMA_64x224x32_F16E5M2E5M2_SS_TN; + template struct MMA_Traits> { @@ -15543,6 +19677,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x224x32_F16E5M2E5M2_RS_TN = SM90::GMMA::MMA_64x224x32_F16E5M2E5M2_RS_TN; + template struct MMA_Traits> { @@ -15566,6 +19707,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x224x32_F32E5M2E5M2_SS_TN = SM90::GMMA::MMA_64x224x32_F32E5M2E5M2_SS_TN; + template struct MMA_Traits> { @@ -15590,6 +19738,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x224x32_F32E5M2E5M2_RS_TN = SM90::GMMA::MMA_64x224x32_F32E5M2E5M2_RS_TN; + template struct MMA_Traits> { @@ -15613,6 +19768,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x240x32_F16E5M2E5M2_SS_TN = SM90::GMMA::MMA_64x240x32_F16E5M2E5M2_SS_TN; + template struct MMA_Traits> { 
@@ -15637,6 +19799,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x240x32_F16E5M2E5M2_RS_TN = SM90::GMMA::MMA_64x240x32_F16E5M2E5M2_RS_TN; + template struct MMA_Traits> { @@ -15660,6 +19829,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x240x32_F32E5M2E5M2_SS_TN = SM90::GMMA::MMA_64x240x32_F32E5M2E5M2_SS_TN; + template struct MMA_Traits> { @@ -15684,6 +19860,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x240x32_F32E5M2E5M2_RS_TN = SM90::GMMA::MMA_64x240x32_F32E5M2E5M2_RS_TN; + template struct MMA_Traits> { @@ -15706,6 +19889,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x256x32_F16E5M2E5M2_SS_TN = SM90::GMMA::MMA_64x256x32_F16E5M2E5M2_SS_TN; + template struct MMA_Traits> { @@ -15728,6 +19918,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x256x32_F16E5M2E5M2_RS_TN = SM90::GMMA::MMA_64x256x32_F16E5M2E5M2_RS_TN; + template struct MMA_Traits> { @@ -15749,6 +19946,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x256x32_F32E5M2E5M2_SS_TN = SM90::GMMA::MMA_64x256x32_F32E5M2E5M2_SS_TN; + template struct MMA_Traits> { @@ -15771,6 +19975,13 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x256x32_F32E5M2E5M2_RS_TN = SM90::GMMA::MMA_64x256x32_F32E5M2E5M2_RS_TN; + template struct MMA_Traits> { @@ -15790,6 +20001,7 @@ struct MMA_Traits> GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; + //////////////////////////////////////////////////////////////////////////////////////////////////// } // end namespace cute diff --git a/include/cute/atom/mma_traits_sm90_gmma_sparse.hpp b/include/cute/atom/mma_traits_sm90_gmma_sparse.hpp new file mode 100644 index 0000000000..7252a0ef58 --- /dev/null +++ b/include/cute/atom/mma_traits_sm90_gmma_sparse.hpp @@ -0,0 +1,16915 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+ * SPDX-License-Identifier: BSD-3-Clause
+ *
+ * Redistribution and use in source and binary forms, with or without
+ * modification, are permitted provided that the following conditions are met:
+ *
+ * 1. Redistributions of source code must retain the above copyright notice, this
+ * list of conditions and the following disclaimer.
+ *
+ * 2. Redistributions in binary form must reproduce the above copyright notice,
+ * this list of conditions and the following disclaimer in the documentation
+ * and/or other materials provided with the distribution.
+ *
+ * 3. Neither the name of the copyright holder nor the names of its
+ * contributors may be used to endorse or promote products derived from
+ * this software without specific prior written permission.
+ *
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+ *
+ **************************************************************************************************/
+
+#pragma once
+
+#include // cute::smem_sparse_ptr_flag
+#include // cute::Swizzle
+#include // cute::Tensor
+#include // cute::LayoutType
+#include // cute::SM90::SPARSE::GMMA_64x8x32_F16F16F16_SS, etc
+#include // cute::GMMA::Layout_*
+#include // cute::MMA_Traits
+#include // cute::ComposedLayout
+#include // cute::is_static
+
+namespace cute {
+
+namespace SM90::GMMA {
+
+///////////////////////////////////////////
+// Common layouts for GMMA Shared Memory //
+///////////////////////////////////////////
+
+// M|N-major layouts in units of Type and sparsity factor S
+template
+using Layout_MN_INTER_SpAtom = ComposedLayout, smem_sparse_ptr_flag_bits>,
+                                               decltype(blocked_product(Layout>>{}, Layout_MN_INTER_Atom{}.layout_b()))>;
+template
+using Layout_MN_SW32_SpAtom = ComposedLayout, smem_sparse_ptr_flag_bits>,
+                                              decltype(blocked_product(Layout>>{}, Layout_MN_SW32_Atom{}.layout_b()))>;
+template
+using Layout_MN_SW64_SpAtom = ComposedLayout, smem_sparse_ptr_flag_bits>,
+                                              decltype(blocked_product(Layout>>{}, Layout_MN_SW64_Atom{}.layout_b()))>;
+template
+using Layout_MN_SW128_SpAtom = ComposedLayout, smem_sparse_ptr_flag_bits>,
+                                               decltype(blocked_product(Layout>>{}, Layout_MN_SW128_Atom{}.layout_b()))>;
+
+// K-major layouts in units of Type and sparsity factor S
+template
+using Layout_K_INTER_SpAtom = ComposedLayout, smem_sparse_ptr_flag_bits>,
+                                              decltype(blocked_product(Layout>>{}, Layout_K_INTER_Atom{}.layout_b()))>;
+template
+using Layout_K_SW32_SpAtom = ComposedLayout, smem_sparse_ptr_flag_bits>,
+                                             decltype(blocked_product(Layout>>{}, Layout_K_SW32_Atom{}.layout_b()))>;
+template
+using Layout_K_SW64_SpAtom = ComposedLayout, smem_sparse_ptr_flag_bits>,
+                                             decltype(blocked_product(Layout>>{}, Layout_K_SW64_Atom{}.layout_b()))>;
+template
+using Layout_K_SW128_SpAtom = ComposedLayout, smem_sparse_ptr_flag_bits>,
+                                              decltype(blocked_product(Layout>>{}, Layout_K_SW128_Atom{}.layout_b()))>;
+
+// With GMMA::Major param
+template
+using Layout_INTER_SpAtom = typename conditional,
+                                                 Layout_K_INTER_SpAtom>::type;
+template
+using Layout_SW32_SpAtom = typename conditional,
+                                                Layout_K_SW32_SpAtom>::type;
+template
+using Layout_SW64_SpAtom = typename conditional,
+                                                Layout_K_SW64_SpAtom>::type;
+template
+using Layout_SW128_SpAtom = typename conditional,
+                                                 Layout_K_SW128_SpAtom>::type;
+
+///////////////////////////////////////////////////////////////////////////////
+// Higher level GMMA Descriptor utilities
+///////////////////////////////////////////////////////////////////////////////
+
+template
+struct sparse_smem_desc : DescriptorIterator {};
+
+} // end namespace SM90::GMMA
+
+// Customization point for creating a cute::GMMA::sparse_smem_desc Tensor
+template
+struct MakeTensor>
+{
+  // Note that this is the exact same as cute::GMMA::smem_desc above, plus additional static checks.
+  template
+  CUTE_HOST_DEVICE constexpr auto
+  operator()(Tensor const& smem_tensor)
+  {
+    static_assert(is_smem::value, "Expected SMEM Tensor to construct a GMMA Desc Tensor");
+    static_assert(is_sparse::value, "Expected sparse value_type.");
+    static_assert(is_sparse_ptr::value, "Expected sparse iter.");
+    return make_tensor(SM90::GMMA::DescriptorIterator{SM90::GMMA::make_gmma_desc(tensor<0>(smem_tensor))},
+                       replace<0>(recast(smem_tensor).layout(), Layout<_1,_0>{}));
+  }
+};
+
+///////////////////////////////////////////////////////////////////////////////
+//////////////////////////// MMA_TRAITS ///////////////////////////////////////
+///////////////////////////////////////////////////////////////////////////////
+
+namespace SM90::GMMA {
+
+// Metadata layouts
+using ELayout_64x64 = Layout, Shape <_32>>,
+                             Stride, Stride<_64>>>;
+
+using ELayout_64x32 = Layout, Shape <_16,_2>>,
+                             Stride, Stride<_64,_8>>>;
+
+using ELayout_64x16 = Layout, Shape < _8,_2>>,
+                             Stride, Stride<_64,_8>>>;
+
+} // namespace SM90::GMMA
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+namespace SM90::GMMA::SPARSE {
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+template
+CUTE_HOST_DEVICE constexpr void
+mma_unpack(MMA_Traits const& traits,
+           Tensor      & D,
+           Tensor const& A_zipped,
+           Tensor const& B,
+           Tensor const& C)
+{
+  static_assert(is_rmem_v, "Expected registers in MMA_Atom::call");
+  static_assert(is_rmem_v, "Expected registers in MMA_Atom::call");
+  static_assert(is_rmem_v, "Expected registers in MMA_Atom::call");
+  static_assert(is_rmem_v, "Expected registers in MMA_Atom::call");
+
+  using DRegisters = typename MMAOp::DRegisters;
+  using ARegisters = typename MMAOp::ARegisters;
+  using ERegisters = typename MMAOp::ERegisters;
+  using BRegisters = typename MMAOp::BRegisters;
+  using CRegisters = typename MMAOp::CRegisters;
+
+  // Register value types from the MMAOp register arrays
+  using RegTypeD = typename remove_extent<DRegisters>::type;
+  using RegTypeA = typename remove_extent<ARegisters>::type;
+  using RegTypeE = typename remove_extent<ERegisters>::type;
+  using RegTypeB = typename remove_extent<BRegisters>::type;
+  using RegTypeC = typename remove_extent<CRegisters>::type;
+
+  constexpr int RegNumA = extent<ARegisters>::value;
+  constexpr int RegNumE = extent<ERegisters>::value;
+  constexpr int RegNumB = extent<BRegisters>::value;
+  constexpr int RegNumC = extent<CRegisters>::value;
+
+  auto [A, E] = unzip_tensor(A_zipped);
+  Tensor rA = recast<RegTypeA>(A);
+  Tensor rE = recast<RegTypeE>(E);
+  Tensor rB = recast<RegTypeB>(B);
+
+  CUTE_STATIC_ASSERT_V(size(rA) == Int<RegNumA>{});
+  CUTE_STATIC_ASSERT_V(size(rE) == Int<RegNumE>{});
+  CUTE_STATIC_ASSERT_V(size(rB) == Int<RegNumB>{});
+
+  static_assert(is_same<RegTypeD, void>::value, "GMMA DRegisters must have void type.");
+  static_assert(is_same::value, "GMMA C and D value_type must match.");
+  static_assert(is_same::value, "GMMA C and D layouts must match.");
+
+  Tensor rC = recast<RegTypeC>(D);  // NOTE: D and C are same, so use mutable D
+
+  CUTE_STATIC_ASSERT_V(size(rC) == Int<RegNumC>{});
+
+  detail::explode(MMAOp::fma,
+                  rA, make_int_sequence<RegNumA>{},
+                  rB, make_int_sequence<RegNumB>{},
+                  rC, make_int_sequence<RegNumC>{},
+                  rE, make_int_sequence<RegNumE>{},
+                  &(traits.accumulate_), seq<0>{});
+}
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+} // namespace SM90::GMMA::SPARSE
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+template
+struct MMA_Traits>
+{
+  using ValTypeD = half_t;
+  using ValTypeA = sparse_elem<2, half_t>;
+  using ValTypeE = sparse_elem<8, uint8_t>;
+  using ValTypeB = half_t;
+  using ValTypeC = half_t;
+
+  using FrgTypeA = GMMA::smem_desc;
+  using FrgTypeB = GMMA::smem_desc;
+
+  using Shape_MNK = Shape<_64,_8,_32>;
+  using ThrID   = Layout<_128>;
+  using ALayout = GMMA::ABLayout< 64, 32>;
+  using ELayout = GMMA::ELayout_64x32;
+  using BLayout = GMMA::ABLayout<  8, 32>;
+  using CLayout = GMMA::CLayout_64x8;
+
+  GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One;
+};
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+template
+struct MMA_Traits>
+{
+  using ValTypeD = half_t;
+  using ValTypeA = sparse_elem<2, half_t>;
+  using ValTypeE = sparse_elem<8, uint8_t>;
+  using ValTypeB = half_t;
+  using ValTypeC = half_t;
+
+  using FrgTypeB = GMMA::smem_desc;
+
+  using Shape_MNK = Shape<_64,_8,_32>;
+  using ThrID   = Layout<_128>;
+  using ALayout = GMMA::ALayout_64x32;
+  using ELayout = GMMA::ELayout_64x32;
+  using BLayout = GMMA::ABLayout<  8, 32>;
+  using CLayout = GMMA::CLayout_64x8;
+
+  GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One;
+};
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+template
+struct MMA_Traits>
+{
+  using ValTypeD = half_t;
+  using ValTypeA = sparse_elem<2, half_t>;
+  using ValTypeE = sparse_elem<8, uint8_t>;
+  using ValTypeB = half_t;
+  using ValTypeC = half_t;
+
+  using FrgTypeA = GMMA::smem_desc;
+  using FrgTypeB = GMMA::smem_desc;
+
+  using Shape_MNK = Shape<_64,_16,_32>;
+  using ThrID   = Layout<_128>;
+  using ALayout = GMMA::ABLayout< 64, 32>;
+  using ELayout = GMMA::ELayout_64x32;
+  using BLayout = GMMA::ABLayout< 16, 32>;
+  using CLayout = GMMA::CLayout_64x16;
+
+  GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One;
+};
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+template
+struct MMA_Traits>
+{
+  using ValTypeD = half_t;
+  using ValTypeA = sparse_elem<2, half_t>;
+  using ValTypeE = sparse_elem<8, uint8_t>;
+  using ValTypeB = half_t;
+  using ValTypeC = half_t;
+
+  using FrgTypeB = GMMA::smem_desc;
+
+  using Shape_MNK = Shape<_64,_16,_32>;
+  using ThrID   = Layout<_128>;
+  using ALayout = GMMA::ALayout_64x32;
+  using ELayout = GMMA::ELayout_64x32;
+  using BLayout = GMMA::ABLayout< 16, 32>;
+  using CLayout = GMMA::CLayout_64x16;
+
+  GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One;
+};
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+template
+struct MMA_Traits>
+{
+  using ValTypeD = half_t;
using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_32,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout< 32, 32>; + using CLayout = GMMA::CLayout_64x32; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_32,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout< 32, 32>; + using CLayout = GMMA::CLayout_64x32; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_48,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout< 48, 32>; + using CLayout = GMMA::CLayout_64x48; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_48,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout< 48, 32>; + using CLayout = GMMA::CLayout_64x48; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_64,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout< 64, 32>; + using CLayout = GMMA::CLayout_64x64; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = half_t; 
+ + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_64,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout< 64, 32>; + using CLayout = GMMA::CLayout_64x64; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_80,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout< 80, 32>; + using CLayout = GMMA::CLayout_64x80; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_80,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout< 80, 32>; + using CLayout = GMMA::CLayout_64x80; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_96,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout< 96, 32>; + using CLayout = GMMA::CLayout_64x96; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_96,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout< 96, 32>; + using CLayout = GMMA::CLayout_64x96; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_112,_32>; + using 
ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<112, 32>; + using CLayout = GMMA::CLayout_64x112; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_112,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<112, 32>; + using CLayout = GMMA::CLayout_64x112; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_128,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<128, 32>; + using CLayout = GMMA::CLayout_64x128; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_128,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<128, 32>; + using CLayout = GMMA::CLayout_64x128; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_144,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<144, 32>; + using CLayout = GMMA::CLayout_64x144; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_144,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using ELayout = GMMA::ELayout_64x32; + 
using BLayout = GMMA::ABLayout<144, 32>; + using CLayout = GMMA::CLayout_64x144; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_160,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<160, 32>; + using CLayout = GMMA::CLayout_64x160; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_160,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<160, 32>; + using CLayout = GMMA::CLayout_64x160; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_176,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<176, 32>; + using CLayout = GMMA::CLayout_64x176; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_176,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<176, 32>; + using CLayout = GMMA::CLayout_64x176; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_192,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using ELayout = GMMA::ELayout_64x32; + 
using BLayout = GMMA::ABLayout<192, 32>; + using CLayout = GMMA::CLayout_64x192; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_192,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<192, 32>; + using CLayout = GMMA::CLayout_64x192; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_208,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<208, 32>; + using CLayout = GMMA::CLayout_64x208; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_208,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<208, 32>; + using CLayout = GMMA::CLayout_64x208; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_224,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<224, 32>; + using CLayout = GMMA::CLayout_64x224; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_224,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<224, 32>; + using CLayout 
= GMMA::CLayout_64x224; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_240,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<240, 32>; + using CLayout = GMMA::CLayout_64x240; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_240,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<240, 32>; + using CLayout = GMMA::CLayout_64x240; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_256,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<256, 32>; + using CLayout = GMMA::CLayout_64x256; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_256,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<256, 32>; + using CLayout = GMMA::CLayout_64x256; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_8,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout< 8, 32>; + using CLayout = GMMA::CLayout_64x8; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + 
+//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_8,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout< 8, 32>; + using CLayout = GMMA::CLayout_64x8; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_16,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout< 16, 32>; + using CLayout = GMMA::CLayout_64x16; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_16,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout< 16, 32>; + using CLayout = GMMA::CLayout_64x16; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_32,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout< 32, 32>; + using CLayout = GMMA::CLayout_64x32; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_32,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout< 32, 32>; + using CLayout = GMMA::CLayout_64x32; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using 
ValTypeB = half_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_48,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout< 48, 32>; + using CLayout = GMMA::CLayout_64x48; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_48,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout< 48, 32>; + using CLayout = GMMA::CLayout_64x48; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_64,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout< 64, 32>; + using CLayout = GMMA::CLayout_64x64; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_64,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout< 64, 32>; + using CLayout = GMMA::CLayout_64x64; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_80,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout< 80, 32>; + using CLayout = GMMA::CLayout_64x80; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + 
using Shape_MNK = Shape<_64,_80,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout< 80, 32>; + using CLayout = GMMA::CLayout_64x80; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_96,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout< 96, 32>; + using CLayout = GMMA::CLayout_64x96; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_96,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout< 96, 32>; + using CLayout = GMMA::CLayout_64x96; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_112,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<112, 32>; + using CLayout = GMMA::CLayout_64x112; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_112,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<112, 32>; + using CLayout = GMMA::CLayout_64x112; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_128,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using ELayout = 
GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<128, 32>; + using CLayout = GMMA::CLayout_64x128; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_128,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<128, 32>; + using CLayout = GMMA::CLayout_64x128; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_144,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<144, 32>; + using CLayout = GMMA::CLayout_64x144; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_144,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<144, 32>; + using CLayout = GMMA::CLayout_64x144; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_160,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<160, 32>; + using CLayout = GMMA::CLayout_64x160; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_160,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<160, 32>; + 
using CLayout = GMMA::CLayout_64x160; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_176,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<176, 32>; + using CLayout = GMMA::CLayout_64x176; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_176,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<176, 32>; + using CLayout = GMMA::CLayout_64x176; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_192,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<192, 32>; + using CLayout = GMMA::CLayout_64x192; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_192,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<192, 32>; + using CLayout = GMMA::CLayout_64x192; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_208,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<208, 32>; + using CLayout = GMMA::CLayout_64x208; + + GMMA::ScaleOut accumulate_ = 
GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_208,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<208, 32>; + using CLayout = GMMA::CLayout_64x208; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_224,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<224, 32>; + using CLayout = GMMA::CLayout_64x224; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_224,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<224, 32>; + using CLayout = GMMA::CLayout_64x224; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_240,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<240, 32>; + using CLayout = GMMA::CLayout_64x240; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_240,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<240, 32>; + using CLayout = GMMA::CLayout_64x240; + + GMMA::ScaleOut accumulate_ = 
GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_256,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<256, 32>; + using CLayout = GMMA::CLayout_64x256; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_256,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<256, 32>; + using CLayout = GMMA::CLayout_64x256; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, bfloat16_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_8,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout< 8, 32>; + using CLayout = GMMA::CLayout_64x8; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, bfloat16_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_8,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout< 8, 32>; + using CLayout = GMMA::CLayout_64x8; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, bfloat16_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_16,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout< 16, 32>; + using CLayout = GMMA::CLayout_64x16; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, bfloat16_t>; + 
using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_16,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout< 16, 32>; + using CLayout = GMMA::CLayout_64x16; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, bfloat16_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_32,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout< 32, 32>; + using CLayout = GMMA::CLayout_64x32; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, bfloat16_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_32,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout< 32, 32>; + using CLayout = GMMA::CLayout_64x32; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, bfloat16_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_48,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout< 48, 32>; + using CLayout = GMMA::CLayout_64x48; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, bfloat16_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_48,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout< 48, 32>; + using CLayout = GMMA::CLayout_64x48; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, bfloat16_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using 
FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_64,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout< 64, 32>; + using CLayout = GMMA::CLayout_64x64; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, bfloat16_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_64,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout< 64, 32>; + using CLayout = GMMA::CLayout_64x64; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, bfloat16_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_80,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout< 80, 32>; + using CLayout = GMMA::CLayout_64x80; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, bfloat16_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_80,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout< 80, 32>; + using CLayout = GMMA::CLayout_64x80; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, bfloat16_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_96,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout< 96, 32>; + using CLayout = GMMA::CLayout_64x96; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, bfloat16_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_96,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + 
using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout< 96, 32>; + using CLayout = GMMA::CLayout_64x96; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, bfloat16_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_112,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<112, 32>; + using CLayout = GMMA::CLayout_64x112; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, bfloat16_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_112,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<112, 32>; + using CLayout = GMMA::CLayout_64x112; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, bfloat16_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_128,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<128, 32>; + using CLayout = GMMA::CLayout_64x128; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, bfloat16_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_128,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<128, 32>; + using CLayout = GMMA::CLayout_64x128; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, bfloat16_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_144,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using ELayout = GMMA::ELayout_64x32; + using 
BLayout = GMMA::ABLayout<144, 32>; + using CLayout = GMMA::CLayout_64x144; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, bfloat16_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_144,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<144, 32>; + using CLayout = GMMA::CLayout_64x144; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, bfloat16_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_160,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<160, 32>; + using CLayout = GMMA::CLayout_64x160; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, bfloat16_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_160,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<160, 32>; + using CLayout = GMMA::CLayout_64x160; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, bfloat16_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_176,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<176, 32>; + using CLayout = GMMA::CLayout_64x176; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, bfloat16_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_176,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + 
using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<176, 32>; + using CLayout = GMMA::CLayout_64x176; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, bfloat16_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_192,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<192, 32>; + using CLayout = GMMA::CLayout_64x192; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, bfloat16_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_192,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<192, 32>; + using CLayout = GMMA::CLayout_64x192; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, bfloat16_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_208,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<208, 32>; + using CLayout = GMMA::CLayout_64x208; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, bfloat16_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_208,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<208, 32>; + using CLayout = GMMA::CLayout_64x208; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, bfloat16_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_224,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using ELayout = GMMA::ELayout_64x32; + 
using BLayout = GMMA::ABLayout<224, 32>; + using CLayout = GMMA::CLayout_64x224; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, bfloat16_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_224,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<224, 32>; + using CLayout = GMMA::CLayout_64x224; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, bfloat16_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_240,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<240, 32>; + using CLayout = GMMA::CLayout_64x240; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, bfloat16_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_240,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<240, 32>; + using CLayout = GMMA::CLayout_64x240; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, bfloat16_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_256,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<256, 32>; + using CLayout = GMMA::CLayout_64x256; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, bfloat16_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_256,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<256, 32>; + using CLayout = 
GMMA::CLayout_64x256; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, tfloat32_t>; + using ValTypeE = sparse_elem<4, uint8_t>; + using ValTypeB = tfloat32_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_8,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 16>; + using ELayout = GMMA::ELayout_64x16; + using BLayout = GMMA::ABLayout< 8, 16>; + using CLayout = GMMA::CLayout_64x8; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, tfloat32_t>; + using ValTypeE = sparse_elem<4, uint8_t>; + using ValTypeB = tfloat32_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_8,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x16; + using ELayout = GMMA::ELayout_64x16; + using BLayout = GMMA::ABLayout< 8, 16>; + using CLayout = GMMA::CLayout_64x8; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, tfloat32_t>; + using ValTypeE = sparse_elem<4, uint8_t>; + using ValTypeB = tfloat32_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_16,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 16>; + using ELayout = GMMA::ELayout_64x16; + using BLayout = GMMA::ABLayout< 16, 16>; + using CLayout = GMMA::CLayout_64x16; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, tfloat32_t>; + using ValTypeE = sparse_elem<4, uint8_t>; + using ValTypeB = tfloat32_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_16,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x16; + using ELayout = GMMA::ELayout_64x16; + using BLayout = GMMA::ABLayout< 16, 16>; + using CLayout = GMMA::CLayout_64x16; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, tfloat32_t>; + using ValTypeE = sparse_elem<4, uint8_t>; + using ValTypeB = tfloat32_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_32,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 16>; + using ELayout = GMMA::ELayout_64x16; + using BLayout = GMMA::ABLayout< 32, 16>; + using CLayout = GMMA::CLayout_64x32; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD 
= float; + using ValTypeA = sparse_elem<2, tfloat32_t>; + using ValTypeE = sparse_elem<4, uint8_t>; + using ValTypeB = tfloat32_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_32,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x16; + using ELayout = GMMA::ELayout_64x16; + using BLayout = GMMA::ABLayout< 32, 16>; + using CLayout = GMMA::CLayout_64x32; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, tfloat32_t>; + using ValTypeE = sparse_elem<4, uint8_t>; + using ValTypeB = tfloat32_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_48,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 16>; + using ELayout = GMMA::ELayout_64x16; + using BLayout = GMMA::ABLayout< 48, 16>; + using CLayout = GMMA::CLayout_64x48; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, tfloat32_t>; + using ValTypeE = sparse_elem<4, uint8_t>; + using ValTypeB = tfloat32_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_48,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x16; + using ELayout = GMMA::ELayout_64x16; + using BLayout = GMMA::ABLayout< 48, 16>; + using CLayout = GMMA::CLayout_64x48; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, tfloat32_t>; + using ValTypeE = sparse_elem<4, uint8_t>; + using ValTypeB = tfloat32_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_64,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 16>; + using ELayout = GMMA::ELayout_64x16; + using BLayout = GMMA::ABLayout< 64, 16>; + using CLayout = GMMA::CLayout_64x64; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, tfloat32_t>; + using ValTypeE = sparse_elem<4, uint8_t>; + using ValTypeB = tfloat32_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_64,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x16; + using ELayout = GMMA::ELayout_64x16; + using BLayout = GMMA::ABLayout< 64, 16>; + using CLayout = GMMA::CLayout_64x64; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, tfloat32_t>; + using ValTypeE = sparse_elem<4, 
uint8_t>; + using ValTypeB = tfloat32_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_80,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 16>; + using ELayout = GMMA::ELayout_64x16; + using BLayout = GMMA::ABLayout< 80, 16>; + using CLayout = GMMA::CLayout_64x80; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, tfloat32_t>; + using ValTypeE = sparse_elem<4, uint8_t>; + using ValTypeB = tfloat32_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_80,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x16; + using ELayout = GMMA::ELayout_64x16; + using BLayout = GMMA::ABLayout< 80, 16>; + using CLayout = GMMA::CLayout_64x80; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, tfloat32_t>; + using ValTypeE = sparse_elem<4, uint8_t>; + using ValTypeB = tfloat32_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_96,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 16>; + using ELayout = GMMA::ELayout_64x16; + using BLayout = GMMA::ABLayout< 96, 16>; + using CLayout = GMMA::CLayout_64x96; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, tfloat32_t>; + using ValTypeE = sparse_elem<4, uint8_t>; + using ValTypeB = tfloat32_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_96,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x16; + using ELayout = GMMA::ELayout_64x16; + using BLayout = GMMA::ABLayout< 96, 16>; + using CLayout = GMMA::CLayout_64x96; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, tfloat32_t>; + using ValTypeE = sparse_elem<4, uint8_t>; + using ValTypeB = tfloat32_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_112,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 16>; + using ELayout = GMMA::ELayout_64x16; + using BLayout = GMMA::ABLayout<112, 16>; + using CLayout = GMMA::CLayout_64x112; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, tfloat32_t>; + using ValTypeE = sparse_elem<4, uint8_t>; + using ValTypeB = tfloat32_t; + 
using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_112,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x16; + using ELayout = GMMA::ELayout_64x16; + using BLayout = GMMA::ABLayout<112, 16>; + using CLayout = GMMA::CLayout_64x112; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, tfloat32_t>; + using ValTypeE = sparse_elem<4, uint8_t>; + using ValTypeB = tfloat32_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_128,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 16>; + using ELayout = GMMA::ELayout_64x16; + using BLayout = GMMA::ABLayout<128, 16>; + using CLayout = GMMA::CLayout_64x128; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, tfloat32_t>; + using ValTypeE = sparse_elem<4, uint8_t>; + using ValTypeB = tfloat32_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_128,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x16; + using ELayout = GMMA::ELayout_64x16; + using BLayout = GMMA::ABLayout<128, 16>; + using CLayout = GMMA::CLayout_64x128; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, tfloat32_t>; + using ValTypeE = sparse_elem<4, uint8_t>; + using ValTypeB = tfloat32_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_144,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 16>; + using ELayout = GMMA::ELayout_64x16; + using BLayout = GMMA::ABLayout<144, 16>; + using CLayout = GMMA::CLayout_64x144; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, tfloat32_t>; + using ValTypeE = sparse_elem<4, uint8_t>; + using ValTypeB = tfloat32_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_144,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x16; + using ELayout = GMMA::ELayout_64x16; + using BLayout = GMMA::ABLayout<144, 16>; + using CLayout = GMMA::CLayout_64x144; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, tfloat32_t>; + using ValTypeE = sparse_elem<4, uint8_t>; + using ValTypeB = tfloat32_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB 
= GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_160,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 16>; + using ELayout = GMMA::ELayout_64x16; + using BLayout = GMMA::ABLayout<160, 16>; + using CLayout = GMMA::CLayout_64x160; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, tfloat32_t>; + using ValTypeE = sparse_elem<4, uint8_t>; + using ValTypeB = tfloat32_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_160,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x16; + using ELayout = GMMA::ELayout_64x16; + using BLayout = GMMA::ABLayout<160, 16>; + using CLayout = GMMA::CLayout_64x160; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, tfloat32_t>; + using ValTypeE = sparse_elem<4, uint8_t>; + using ValTypeB = tfloat32_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_176,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 16>; + using ELayout = GMMA::ELayout_64x16; + using BLayout = GMMA::ABLayout<176, 16>; + using CLayout = GMMA::CLayout_64x176; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, tfloat32_t>; + using ValTypeE = sparse_elem<4, uint8_t>; + using ValTypeB = tfloat32_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_176,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x16; + using ELayout = GMMA::ELayout_64x16; + using BLayout = GMMA::ABLayout<176, 16>; + using CLayout = GMMA::CLayout_64x176; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, tfloat32_t>; + using ValTypeE = sparse_elem<4, uint8_t>; + using ValTypeB = tfloat32_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_192,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 16>; + using ELayout = GMMA::ELayout_64x16; + using BLayout = GMMA::ABLayout<192, 16>; + using CLayout = GMMA::CLayout_64x192; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, tfloat32_t>; + using ValTypeE = sparse_elem<4, uint8_t>; + using ValTypeB = tfloat32_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_192,_16>; + 
using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x16; + using ELayout = GMMA::ELayout_64x16; + using BLayout = GMMA::ABLayout<192, 16>; + using CLayout = GMMA::CLayout_64x192; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, tfloat32_t>; + using ValTypeE = sparse_elem<4, uint8_t>; + using ValTypeB = tfloat32_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_208,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 16>; + using ELayout = GMMA::ELayout_64x16; + using BLayout = GMMA::ABLayout<208, 16>; + using CLayout = GMMA::CLayout_64x208; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, tfloat32_t>; + using ValTypeE = sparse_elem<4, uint8_t>; + using ValTypeB = tfloat32_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_208,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x16; + using ELayout = GMMA::ELayout_64x16; + using BLayout = GMMA::ABLayout<208, 16>; + using CLayout = GMMA::CLayout_64x208; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, tfloat32_t>; + using ValTypeE = sparse_elem<4, uint8_t>; + using ValTypeB = tfloat32_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_224,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 16>; + using ELayout = GMMA::ELayout_64x16; + using BLayout = GMMA::ABLayout<224, 16>; + using CLayout = GMMA::CLayout_64x224; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, tfloat32_t>; + using ValTypeE = sparse_elem<4, uint8_t>; + using ValTypeB = tfloat32_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_224,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x16; + using ELayout = GMMA::ELayout_64x16; + using BLayout = GMMA::ABLayout<224, 16>; + using CLayout = GMMA::CLayout_64x224; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, tfloat32_t>; + using ValTypeE = sparse_elem<4, uint8_t>; + using ValTypeB = tfloat32_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using 
FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_240,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 16>; + using ELayout = GMMA::ELayout_64x16; + using BLayout = GMMA::ABLayout<240, 16>; + using CLayout = GMMA::CLayout_64x240; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, tfloat32_t>; + using ValTypeE = sparse_elem<4, uint8_t>; + using ValTypeB = tfloat32_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_240,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x16; + using ELayout = GMMA::ELayout_64x16; + using BLayout = GMMA::ABLayout<240, 16>; + using CLayout = GMMA::CLayout_64x240; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, tfloat32_t>; + using ValTypeE = sparse_elem<4, uint8_t>; + using ValTypeB = tfloat32_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_256,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 16>; + using ELayout = GMMA::ELayout_64x16; + using BLayout = GMMA::ABLayout<256, 16>; + using CLayout = GMMA::CLayout_64x256; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, tfloat32_t>; + using ValTypeE = sparse_elem<4, uint8_t>; + using ValTypeB = tfloat32_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_256,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x16; + using ELayout = GMMA::ELayout_64x16; + using BLayout = GMMA::ABLayout<256, 16>; + using CLayout = GMMA::CLayout_64x256; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_8,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 8, 64>; + using CLayout = GMMA::CLayout_64x8; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_8,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = 
GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 8, 64>; + using CLayout = GMMA::CLayout_64x8; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_16,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 16, 64>; + using CLayout = GMMA::CLayout_64x16; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_16,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 16, 64>; + using CLayout = GMMA::CLayout_64x16; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_32,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 32, 64>; + using CLayout = GMMA::CLayout_64x32; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_32,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 32, 64>; + using CLayout = GMMA::CLayout_64x32; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_48,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 48, 64>; + using CLayout = GMMA::CLayout_64x48; + + 
GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_48,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 48, 64>; + using CLayout = GMMA::CLayout_64x48; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_64,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 64, 64>; + using CLayout = GMMA::CLayout_64x64; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_64,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 64, 64>; + using CLayout = GMMA::CLayout_64x64; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_80,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 80, 64>; + using CLayout = GMMA::CLayout_64x80; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_80,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 80, 64>; + using CLayout = 
GMMA::CLayout_64x80; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_96,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 96, 64>; + using CLayout = GMMA::CLayout_64x96; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_96,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 96, 64>; + using CLayout = GMMA::CLayout_64x96; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_112,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<112, 64>; + using CLayout = GMMA::CLayout_64x112; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_112,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<112, 64>; + using CLayout = GMMA::CLayout_64x112; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_128,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<128, 64>; + using CLayout = GMMA::CLayout_64x128; + + 
GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_128,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<128, 64>; + using CLayout = GMMA::CLayout_64x128; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_144,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<144, 64>; + using CLayout = GMMA::CLayout_64x144; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_144,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<144, 64>; + using CLayout = GMMA::CLayout_64x144; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_160,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<160, 64>; + using CLayout = GMMA::CLayout_64x160; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_160,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using 
BLayout = GMMA::ABLayout<160, 64>; + using CLayout = GMMA::CLayout_64x160; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_176,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<176, 64>; + using CLayout = GMMA::CLayout_64x176; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_176,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<176, 64>; + using CLayout = GMMA::CLayout_64x176; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_192,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<192, 64>; + using CLayout = GMMA::CLayout_64x192; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_192,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<192, 64>; + using CLayout = GMMA::CLayout_64x192; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_208,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = 
GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<208, 64>; + using CLayout = GMMA::CLayout_64x208; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_208,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<208, 64>; + using CLayout = GMMA::CLayout_64x208; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_224,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<224, 64>; + using CLayout = GMMA::CLayout_64x224; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_224,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<224, 64>; + using CLayout = GMMA::CLayout_64x224; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_240,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<240, 64>; + using CLayout = GMMA::CLayout_64x240; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = 
GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_240,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<240, 64>; + using CLayout = GMMA::CLayout_64x240; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_256,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<256, 64>; + using CLayout = GMMA::CLayout_64x256; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_256,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<256, 64>; + using CLayout = GMMA::CLayout_64x256; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_8,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 8, 64>; + using CLayout = GMMA::CLayout_64x8; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_8,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 8, 64>; + using CLayout = GMMA::CLayout_64x8; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_16,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 16, 64>; + using CLayout = GMMA::CLayout_64x16; + + GMMA::ScaleOut 
accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_16,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 16, 64>; + using CLayout = GMMA::CLayout_64x16; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_32,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 32, 64>; + using CLayout = GMMA::CLayout_64x32; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_32,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 32, 64>; + using CLayout = GMMA::CLayout_64x32; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_48,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 48, 64>; + using CLayout = GMMA::CLayout_64x48; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_48,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 48, 64>; + using CLayout = GMMA::CLayout_64x48; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; 
+ using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_64,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 64, 64>; + using CLayout = GMMA::CLayout_64x64; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_64,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 64, 64>; + using CLayout = GMMA::CLayout_64x64; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_80,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 80, 64>; + using CLayout = GMMA::CLayout_64x80; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_80,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 80, 64>; + using CLayout = GMMA::CLayout_64x80; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_96,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 96, 64>; + using CLayout = GMMA::CLayout_64x96; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_96,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + 
using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 96, 64>; + using CLayout = GMMA::CLayout_64x96; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_112,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<112, 64>; + using CLayout = GMMA::CLayout_64x112; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_112,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<112, 64>; + using CLayout = GMMA::CLayout_64x112; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_128,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<128, 64>; + using CLayout = GMMA::CLayout_64x128; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_128,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<128, 64>; + using CLayout = GMMA::CLayout_64x128; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_144,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<144, 64>; + using CLayout = GMMA::CLayout_64x144; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + 
+//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_144,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<144, 64>; + using CLayout = GMMA::CLayout_64x144; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_160,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<160, 64>; + using CLayout = GMMA::CLayout_64x160; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_160,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<160, 64>; + using CLayout = GMMA::CLayout_64x160; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_176,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<176, 64>; + using CLayout = GMMA::CLayout_64x176; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_176,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<176, 64>; + using CLayout = GMMA::CLayout_64x176; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + 
+//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_192,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<192, 64>; + using CLayout = GMMA::CLayout_64x192; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_192,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<192, 64>; + using CLayout = GMMA::CLayout_64x192; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_208,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<208, 64>; + using CLayout = GMMA::CLayout_64x208; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_208,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<208, 64>; + using CLayout = GMMA::CLayout_64x208; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_224,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<224, 64>; + using CLayout = GMMA::CLayout_64x224; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct 
MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_224,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<224, 64>; + using CLayout = GMMA::CLayout_64x224; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_240,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<240, 64>; + using CLayout = GMMA::CLayout_64x240; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_240,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<240, 64>; + using CLayout = GMMA::CLayout_64x240; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_256,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<256, 64>; + using CLayout = GMMA::CLayout_64x256; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_256,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<256, 64>; + using CLayout = GMMA::CLayout_64x256; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using 
FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_8,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 8, 64>; + using CLayout = GMMA::CLayout_64x8; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_8,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 8, 64>; + using CLayout = GMMA::CLayout_64x8; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_16,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 16, 64>; + using CLayout = GMMA::CLayout_64x16; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_16,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 16, 64>; + using CLayout = GMMA::CLayout_64x16; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_32,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 32, 64>; + using CLayout = GMMA::CLayout_64x32; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_32,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = 
GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 32, 64>; + using CLayout = GMMA::CLayout_64x32; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_48,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 48, 64>; + using CLayout = GMMA::CLayout_64x48; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_48,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 48, 64>; + using CLayout = GMMA::CLayout_64x48; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_64,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 64, 64>; + using CLayout = GMMA::CLayout_64x64; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_64,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 64, 64>; + using CLayout = GMMA::CLayout_64x64; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_80,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using 
ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 80, 64>; + using CLayout = GMMA::CLayout_64x80; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_80,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 80, 64>; + using CLayout = GMMA::CLayout_64x80; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_96,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 96, 64>; + using CLayout = GMMA::CLayout_64x96; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_96,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 96, 64>; + using CLayout = GMMA::CLayout_64x96; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_112,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<112, 64>; + using CLayout = GMMA::CLayout_64x112; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_112,_64>; + using ThrID = Layout<_128>; + using ALayout = 
GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<112, 64>; + using CLayout = GMMA::CLayout_64x112; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_128,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<128, 64>; + using CLayout = GMMA::CLayout_64x128; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_128,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<128, 64>; + using CLayout = GMMA::CLayout_64x128; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_144,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<144, 64>; + using CLayout = GMMA::CLayout_64x144; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_144,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<144, 64>; + using CLayout = GMMA::CLayout_64x144; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_160,_64>; + using ThrID = 
Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<160, 64>; + using CLayout = GMMA::CLayout_64x160; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_160,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<160, 64>; + using CLayout = GMMA::CLayout_64x160; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_176,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<176, 64>; + using CLayout = GMMA::CLayout_64x176; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_176,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<176, 64>; + using CLayout = GMMA::CLayout_64x176; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_192,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<192, 64>; + using CLayout = GMMA::CLayout_64x192; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = 
Shape<_64,_192,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<192, 64>; + using CLayout = GMMA::CLayout_64x192; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_208,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<208, 64>; + using CLayout = GMMA::CLayout_64x208; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_208,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<208, 64>; + using CLayout = GMMA::CLayout_64x208; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_224,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<224, 64>; + using CLayout = GMMA::CLayout_64x224; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_224,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<224, 64>; + using CLayout = GMMA::CLayout_64x224; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using 
ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_240,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<240, 64>; + using CLayout = GMMA::CLayout_64x240; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_240,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<240, 64>; + using CLayout = GMMA::CLayout_64x240; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_256,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<256, 64>; + using CLayout = GMMA::CLayout_64x256; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_256,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<256, 64>; + using CLayout = GMMA::CLayout_64x256; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_8,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 8, 64>; + using CLayout = GMMA::CLayout_64x8; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_8,_64>; + 
using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 8, 64>; + using CLayout = GMMA::CLayout_64x8; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_16,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 16, 64>; + using CLayout = GMMA::CLayout_64x16; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_16,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 16, 64>; + using CLayout = GMMA::CLayout_64x16; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_32,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 32, 64>; + using CLayout = GMMA::CLayout_64x32; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_32,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 32, 64>; + using CLayout = GMMA::CLayout_64x32; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_48,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 48, 64>; + using CLayout = GMMA::CLayout_64x48; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + 
+//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_48,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 48, 64>; + using CLayout = GMMA::CLayout_64x48; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_64,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 64, 64>; + using CLayout = GMMA::CLayout_64x64; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_64,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 64, 64>; + using CLayout = GMMA::CLayout_64x64; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_80,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 80, 64>; + using CLayout = GMMA::CLayout_64x80; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_80,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 80, 64>; + using CLayout = GMMA::CLayout_64x80; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = 
sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_96,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 96, 64>; + using CLayout = GMMA::CLayout_64x96; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_96,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 96, 64>; + using CLayout = GMMA::CLayout_64x96; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_112,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<112, 64>; + using CLayout = GMMA::CLayout_64x112; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_112,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<112, 64>; + using CLayout = GMMA::CLayout_64x112; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_128,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<128, 64>; + using CLayout = GMMA::CLayout_64x128; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_128,_64>; + using ThrID = Layout<_128>; + 
using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<128, 64>; + using CLayout = GMMA::CLayout_64x128; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_144,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<144, 64>; + using CLayout = GMMA::CLayout_64x144; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_144,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<144, 64>; + using CLayout = GMMA::CLayout_64x144; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_160,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<160, 64>; + using CLayout = GMMA::CLayout_64x160; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_160,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<160, 64>; + using CLayout = GMMA::CLayout_64x160; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_176,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = 
GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<176, 64>; + using CLayout = GMMA::CLayout_64x176; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_176,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<176, 64>; + using CLayout = GMMA::CLayout_64x176; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_192,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<192, 64>; + using CLayout = GMMA::CLayout_64x192; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_192,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<192, 64>; + using CLayout = GMMA::CLayout_64x192; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_208,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<208, 64>; + using CLayout = GMMA::CLayout_64x208; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_208,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<208, 64>; + using CLayout = GMMA::CLayout_64x208; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + 
+//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_224,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<224, 64>; + using CLayout = GMMA::CLayout_64x224; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_224,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<224, 64>; + using CLayout = GMMA::CLayout_64x224; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_240,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<240, 64>; + using CLayout = GMMA::CLayout_64x240; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_240,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<240, 64>; + using CLayout = GMMA::CLayout_64x240; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_256,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<256, 64>; + using CLayout = GMMA::CLayout_64x256; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template 
+struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_256,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<256, 64>; + using CLayout = GMMA::CLayout_64x256; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_8,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 8, 64>; + using CLayout = GMMA::CLayout_64x8; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_8,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 8, 64>; + using CLayout = GMMA::CLayout_64x8; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_16,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 16, 64>; + using CLayout = GMMA::CLayout_64x16; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_16,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 16, 64>; + using CLayout = GMMA::CLayout_64x16; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using 
FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_32,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 32, 64>; + using CLayout = GMMA::CLayout_64x32; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_32,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 32, 64>; + using CLayout = GMMA::CLayout_64x32; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_48,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 48, 64>; + using CLayout = GMMA::CLayout_64x48; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_48,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 48, 64>; + using CLayout = GMMA::CLayout_64x48; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_64,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 64, 64>; + using CLayout = GMMA::CLayout_64x64; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB 
= GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_64,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 64, 64>; + using CLayout = GMMA::CLayout_64x64; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_80,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 80, 64>; + using CLayout = GMMA::CLayout_64x80; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_80,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 80, 64>; + using CLayout = GMMA::CLayout_64x80; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_96,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 96, 64>; + using CLayout = GMMA::CLayout_64x96; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_96,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 96, 64>; + using CLayout = GMMA::CLayout_64x96; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using 
FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_112,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<112, 64>; + using CLayout = GMMA::CLayout_64x112; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_112,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<112, 64>; + using CLayout = GMMA::CLayout_64x112; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_128,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<128, 64>; + using CLayout = GMMA::CLayout_64x128; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_128,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<128, 64>; + using CLayout = GMMA::CLayout_64x128; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_144,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<144, 64>; + using CLayout = GMMA::CLayout_64x144; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = 
GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_144,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<144, 64>; + using CLayout = GMMA::CLayout_64x144; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_160,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<160, 64>; + using CLayout = GMMA::CLayout_64x160; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_160,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<160, 64>; + using CLayout = GMMA::CLayout_64x160; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_176,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<176, 64>; + using CLayout = GMMA::CLayout_64x176; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_176,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<176, 64>; + using CLayout = GMMA::CLayout_64x176; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = 
sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_192,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<192, 64>; + using CLayout = GMMA::CLayout_64x192; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_192,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<192, 64>; + using CLayout = GMMA::CLayout_64x192; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_208,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<208, 64>; + using CLayout = GMMA::CLayout_64x208; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_208,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<208, 64>; + using CLayout = GMMA::CLayout_64x208; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_224,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<224, 64>; + using CLayout = GMMA::CLayout_64x224; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using 
ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_224,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<224, 64>; + using CLayout = GMMA::CLayout_64x224; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_240,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<240, 64>; + using CLayout = GMMA::CLayout_64x240; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_240,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<240, 64>; + using CLayout = GMMA::CLayout_64x240; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_256,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<256, 64>; + using CLayout = GMMA::CLayout_64x256; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_256,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<256, 64>; + using CLayout = GMMA::CLayout_64x256; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + 
using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_8,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 8, 64>; + using CLayout = GMMA::CLayout_64x8; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_8,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 8, 64>; + using CLayout = GMMA::CLayout_64x8; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_16,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 16, 64>; + using CLayout = GMMA::CLayout_64x16; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_16,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 16, 64>; + using CLayout = GMMA::CLayout_64x16; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_32,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 32, 64>; + using CLayout = GMMA::CLayout_64x32; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_32,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 32, 
64>; + using CLayout = GMMA::CLayout_64x32; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_48,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 48, 64>; + using CLayout = GMMA::CLayout_64x48; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_48,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 48, 64>; + using CLayout = GMMA::CLayout_64x48; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_64,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 64, 64>; + using CLayout = GMMA::CLayout_64x64; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_64,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 64, 64>; + using CLayout = GMMA::CLayout_64x64; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_80,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 80, 64>; + using CLayout = GMMA::CLayout_64x80; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + 
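Two flavors of these sparse traits alternate through this hunk: specializations that define FrgTypeA = GMMA::smem_desc (the "SS" flavor, where the structured-sparse A operand is handed to the instruction through a shared-memory descriptor and ALayout = GMMA::ABLayout<64, 64> describes the smem tile), and specializations that omit FrgTypeA and use ALayout = GMMA::ALayout_64x64 instead, which appear to be the "RS" flavor with A held in registers. Generic code can tell the two apart by probing for the FrgTypeA member; the small detector below is a sketch written for this note, not something CuTe provides.

#include <type_traits>

// Hypothetical helper (assumption: defined by the user, not by CuTe):
// true when a traits type exposes FrgTypeA, i.e. the SS flavor above.
template <class Traits, class = void>
struct has_FrgTypeA : std::false_type {};

template <class Traits>
struct has_FrgTypeA<Traits, std::void_t<typename Traits::FrgTypeA>>
    : std::true_type {};

// Usage sketch (op names are placeholders for the real instruction tags):
//   static_assert( has_FrgTypeA<cute::MMA_Traits<SomeSparseOp_SS>>::value,
//                  "SS flavor: A is described by an smem descriptor");
//   static_assert(!has_FrgTypeA<cute::MMA_Traits<SomeSparseOp_RS>>::value,
//                  "RS flavor: A lives in registers");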
+//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_80,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 80, 64>; + using CLayout = GMMA::CLayout_64x80; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_96,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 96, 64>; + using CLayout = GMMA::CLayout_64x96; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_96,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 96, 64>; + using CLayout = GMMA::CLayout_64x96; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_112,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<112, 64>; + using CLayout = GMMA::CLayout_64x112; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_112,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<112, 64>; + using CLayout = GMMA::CLayout_64x112; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using 
ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_128,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<128, 64>; + using CLayout = GMMA::CLayout_64x128; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_128,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<128, 64>; + using CLayout = GMMA::CLayout_64x128; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_144,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<144, 64>; + using CLayout = GMMA::CLayout_64x144; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_144,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<144, 64>; + using CLayout = GMMA::CLayout_64x144; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_160,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<160, 64>; + using CLayout = GMMA::CLayout_64x160; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using 
ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_160,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<160, 64>; + using CLayout = GMMA::CLayout_64x160; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_176,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<176, 64>; + using CLayout = GMMA::CLayout_64x176; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_176,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<176, 64>; + using CLayout = GMMA::CLayout_64x176; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_192,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<192, 64>; + using CLayout = GMMA::CLayout_64x192; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_192,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<192, 64>; + using CLayout = GMMA::CLayout_64x192; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_208,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + 
using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<208, 64>; + using CLayout = GMMA::CLayout_64x208; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_208,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<208, 64>; + using CLayout = GMMA::CLayout_64x208; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_224,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<224, 64>; + using CLayout = GMMA::CLayout_64x224; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_224,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<224, 64>; + using CLayout = GMMA::CLayout_64x224; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_240,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<240, 64>; + using CLayout = GMMA::CLayout_64x240; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_240,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = 
GMMA::ABLayout<240, 64>; + using CLayout = GMMA::CLayout_64x240; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_256,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<256, 64>; + using CLayout = GMMA::CLayout_64x256; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_256,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<256, 64>; + using CLayout = GMMA::CLayout_64x256; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_8,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 8, 64>; + using CLayout = GMMA::CLayout_64x8; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_8,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 8, 64>; + using CLayout = GMMA::CLayout_64x8; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_16,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 16, 64>; + using CLayout = GMMA::CLayout_64x16; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + 
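Every specialization in this hunk carries the same set of members, so one annotated sketch covers the pattern. It mirrors the 64x16x64 S32/U8/U8 entry directly above; the instruction tag (HypotheticalSparseOp_64x16x64) and the GMMA::Major arguments passed to smem_desc are placeholders, since this rendering of the patch dropped the template arguments.

// Annotated sketch only, not part of the patch; placed inside namespace cute,
// alongside the real specializations. Op tag and smem_desc arguments are assumptions.
struct HypotheticalSparseOp_64x16x64 {};             // stands in for the real GMMA instruction tag

template <>
struct MMA_Traits<HypotheticalSparseOp_64x16x64>
{
  using ValTypeD = int32_t;                          // type of the D (output) accumulator
  using ValTypeA = sparse_elem<2, uint8_t>;          // structured-sparse A (the 2 reflects the 2:4 compression)
  using ValTypeE = sparse_elem<8, uint8_t>;          // E carries the sparsity metadata for A
  using ValTypeB = uint8_t;                          // dense B operand
  using ValTypeC = int32_t;                          // type of the C (input) accumulator

  using FrgTypeA = GMMA::smem_desc<GMMA::Major::K>;  // SS flavor: A is passed as an smem descriptor
  using FrgTypeB = GMMA::smem_desc<GMMA::Major::K>;  // B likewise

  using Shape_MNK = Shape<_64,_16,_64>;              // one warpgroup-wide MMA tile: M=64, N=16, K=64
  using ThrID     = Layout<_128>;                    // 128 threads (one warpgroup) issue the instruction
  using ALayout   = GMMA::ABLayout< 64, 64>;         // smem layout of the 64x64 A tile
  using ELayout   = GMMA::ELayout_64x64;             // layout of the metadata tile
  using BLayout   = GMMA::ABLayout< 16, 64>;         // smem layout of the 16x64 B tile
  using CLayout   = GMMA::CLayout_64x16;             // how the 64x16 accumulator is spread over threads

  GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One;  // runtime switch: keep (One) or zero (Zero) the accumulator
};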
+template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_16,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 16, 64>; + using CLayout = GMMA::CLayout_64x16; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_32,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 32, 64>; + using CLayout = GMMA::CLayout_64x32; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_32,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 32, 64>; + using CLayout = GMMA::CLayout_64x32; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_48,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 48, 64>; + using CLayout = GMMA::CLayout_64x48; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_48,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 48, 64>; + using CLayout = GMMA::CLayout_64x48; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + 
using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_64,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 64, 64>; + using CLayout = GMMA::CLayout_64x64; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_64,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 64, 64>; + using CLayout = GMMA::CLayout_64x64; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_80,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 80, 64>; + using CLayout = GMMA::CLayout_64x80; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_80,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 80, 64>; + using CLayout = GMMA::CLayout_64x80; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_96,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 96, 64>; + using CLayout = GMMA::CLayout_64x96; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using 
ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_96,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 96, 64>; + using CLayout = GMMA::CLayout_64x96; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_112,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<112, 64>; + using CLayout = GMMA::CLayout_64x112; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_112,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<112, 64>; + using CLayout = GMMA::CLayout_64x112; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_128,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<128, 64>; + using CLayout = GMMA::CLayout_64x128; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_128,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<128, 64>; + using CLayout = GMMA::CLayout_64x128; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using 
ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_144,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<144, 64>; + using CLayout = GMMA::CLayout_64x144; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_144,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<144, 64>; + using CLayout = GMMA::CLayout_64x144; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_160,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<160, 64>; + using CLayout = GMMA::CLayout_64x160; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_160,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<160, 64>; + using CLayout = GMMA::CLayout_64x160; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_176,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<176, 64>; + using CLayout = GMMA::CLayout_64x176; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + 
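The N = 48, 80, 112, 144, 160, 176, 208, 224 and 240 specializations in this hunk are all wrapped in #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED), while the N = 8, 16, 32, 64, 96, 128, 192 and 256 ones are unconditional. A configuration that wants one of the extended tile widths therefore has to opt in before the CuTe headers are included; the fragment below sketches that, assuming the macro is supplied by the build (for example via -DCUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED).

// Sketch: picking a tile-N under the extended-shapes switch.
#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED)
using TileN = cute::_240;   // extended width: its MMA_Traits only exist when the macro is defined
#else
using TileN = cute::_256;   // always-available width
#endif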
+//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_176,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<176, 64>; + using CLayout = GMMA::CLayout_64x176; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_192,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<192, 64>; + using CLayout = GMMA::CLayout_64x192; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_192,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<192, 64>; + using CLayout = GMMA::CLayout_64x192; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_208,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<208, 64>; + using CLayout = GMMA::CLayout_64x208; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_208,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<208, 64>; + using CLayout = GMMA::CLayout_64x208; + + GMMA::ScaleOut accumulate_ = 
GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_224,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<224, 64>; + using CLayout = GMMA::CLayout_64x224; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_224,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<224, 64>; + using CLayout = GMMA::CLayout_64x224; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_240,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<240, 64>; + using CLayout = GMMA::CLayout_64x240; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_240,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<240, 64>; + using CLayout = GMMA::CLayout_64x240; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_256,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = 
GMMA::ABLayout<256, 64>; + using CLayout = GMMA::CLayout_64x256; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_256,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<256, 64>; + using CLayout = GMMA::CLayout_64x256; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_8,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 8, 64>; + using CLayout = GMMA::CLayout_64x8; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_8,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 8, 64>; + using CLayout = GMMA::CLayout_64x8; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_16,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 16, 64>; + using CLayout = GMMA::CLayout_64x16; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_16,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 16, 64>; + using CLayout = GMMA::CLayout_64x16; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, 
uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_32,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 32, 64>; + using CLayout = GMMA::CLayout_64x32; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_32,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 32, 64>; + using CLayout = GMMA::CLayout_64x32; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_48,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 48, 64>; + using CLayout = GMMA::CLayout_64x48; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_48,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 48, 64>; + using CLayout = GMMA::CLayout_64x48; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_64,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 64, 64>; + using CLayout = GMMA::CLayout_64x64; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_64,_64>; + using ThrID = Layout<_128>; + using ALayout = 
GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 64, 64>; + using CLayout = GMMA::CLayout_64x64; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_80,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 80, 64>; + using CLayout = GMMA::CLayout_64x80; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_80,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 80, 64>; + using CLayout = GMMA::CLayout_64x80; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_96,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 96, 64>; + using CLayout = GMMA::CLayout_64x96; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_96,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 96, 64>; + using CLayout = GMMA::CLayout_64x96; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_112,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<112, 64>; + using CLayout = GMMA::CLayout_64x112; + + GMMA::ScaleOut accumulate_ = 
GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_112,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<112, 64>; + using CLayout = GMMA::CLayout_64x112; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_128,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<128, 64>; + using CLayout = GMMA::CLayout_64x128; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_128,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<128, 64>; + using CLayout = GMMA::CLayout_64x128; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_144,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<144, 64>; + using CLayout = GMMA::CLayout_64x144; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_144,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<144, 64>; + using CLayout = GMMA::CLayout_64x144; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if 
defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_160,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<160, 64>; + using CLayout = GMMA::CLayout_64x160; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_160,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<160, 64>; + using CLayout = GMMA::CLayout_64x160; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_176,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<176, 64>; + using CLayout = GMMA::CLayout_64x176; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_176,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<176, 64>; + using CLayout = GMMA::CLayout_64x176; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_192,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<192, 64>; + using CLayout = GMMA::CLayout_64x192; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using 
ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_192,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<192, 64>; + using CLayout = GMMA::CLayout_64x192; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_208,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<208, 64>; + using CLayout = GMMA::CLayout_64x208; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_208,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<208, 64>; + using CLayout = GMMA::CLayout_64x208; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_224,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<224, 64>; + using CLayout = GMMA::CLayout_64x224; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_224,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<224, 64>; + using CLayout = GMMA::CLayout_64x224; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = 
uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_240,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<240, 64>; + using CLayout = GMMA::CLayout_64x240; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_240,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<240, 64>; + using CLayout = GMMA::CLayout_64x240; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_256,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<256, 64>; + using CLayout = GMMA::CLayout_64x256; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_256,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<256, 64>; + using CLayout = GMMA::CLayout_64x256; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_8,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 8, 64>; + using CLayout = GMMA::CLayout_64x8; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_8,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + 
using BLayout = GMMA::ABLayout< 8, 64>; + using CLayout = GMMA::CLayout_64x8; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_8,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 8, 64>; + using CLayout = GMMA::CLayout_64x8; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_8,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 8, 64>; + using CLayout = GMMA::CLayout_64x8; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_16,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 16, 64>; + using CLayout = GMMA::CLayout_64x16; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_16,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 16, 64>; + using CLayout = GMMA::CLayout_64x16; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_16,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 16, 64>; + using CLayout = GMMA::CLayout_64x16; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + 
+//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_16,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 16, 64>; + using CLayout = GMMA::CLayout_64x16; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_32,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 32, 64>; + using CLayout = GMMA::CLayout_64x32; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_32,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 32, 64>; + using CLayout = GMMA::CLayout_64x32; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_32,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 32, 64>; + using CLayout = GMMA::CLayout_64x32; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_32,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 32, 64>; + using CLayout = GMMA::CLayout_64x32; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = 
sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_48,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 48, 64>; + using CLayout = GMMA::CLayout_64x48; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_48,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 48, 64>; + using CLayout = GMMA::CLayout_64x48; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_48,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 48, 64>; + using CLayout = GMMA::CLayout_64x48; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_48,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 48, 64>; + using CLayout = GMMA::CLayout_64x48; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_64,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 64, 64>; + using CLayout = GMMA::CLayout_64x64; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, 
float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_64,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 64, 64>; + using CLayout = GMMA::CLayout_64x64; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_64,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 64, 64>; + using CLayout = GMMA::CLayout_64x64; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_64,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 64, 64>; + using CLayout = GMMA::CLayout_64x64; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_80,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 80, 64>; + using CLayout = GMMA::CLayout_64x80; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_80,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 80, 64>; + using CLayout = GMMA::CLayout_64x80; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using 
ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_80,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 80, 64>; + using CLayout = GMMA::CLayout_64x80; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_80,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 80, 64>; + using CLayout = GMMA::CLayout_64x80; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_96,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 96, 64>; + using CLayout = GMMA::CLayout_64x96; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_96,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 96, 64>; + using CLayout = GMMA::CLayout_64x96; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_96,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 96, 64>; + using CLayout = GMMA::CLayout_64x96; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_96,_64>; + 
using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 96, 64>; + using CLayout = GMMA::CLayout_64x96; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_112,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<112, 64>; + using CLayout = GMMA::CLayout_64x112; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_112,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<112, 64>; + using CLayout = GMMA::CLayout_64x112; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_112,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<112, 64>; + using CLayout = GMMA::CLayout_64x112; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_112,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<112, 64>; + using CLayout = GMMA::CLayout_64x112; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + 
using Shape_MNK = Shape<_64,_128,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<128, 64>; + using CLayout = GMMA::CLayout_64x128; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_128,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<128, 64>; + using CLayout = GMMA::CLayout_64x128; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_128,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<128, 64>; + using CLayout = GMMA::CLayout_64x128; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_128,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<128, 64>; + using CLayout = GMMA::CLayout_64x128; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_144,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<144, 64>; + using CLayout = GMMA::CLayout_64x144; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_144,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using 
ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<144, 64>; + using CLayout = GMMA::CLayout_64x144; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_144,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<144, 64>; + using CLayout = GMMA::CLayout_64x144; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_144,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<144, 64>; + using CLayout = GMMA::CLayout_64x144; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_160,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<160, 64>; + using CLayout = GMMA::CLayout_64x160; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_160,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<160, 64>; + using CLayout = GMMA::CLayout_64x160; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK 
= Shape<_64,_160,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<160, 64>; + using CLayout = GMMA::CLayout_64x160; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_160,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<160, 64>; + using CLayout = GMMA::CLayout_64x160; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_176,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<176, 64>; + using CLayout = GMMA::CLayout_64x176; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_176,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<176, 64>; + using CLayout = GMMA::CLayout_64x176; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_176,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<176, 64>; + using CLayout = GMMA::CLayout_64x176; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using 
ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_176,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<176, 64>; + using CLayout = GMMA::CLayout_64x176; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_192,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<192, 64>; + using CLayout = GMMA::CLayout_64x192; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_192,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<192, 64>; + using CLayout = GMMA::CLayout_64x192; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_192,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<192, 64>; + using CLayout = GMMA::CLayout_64x192; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_192,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<192, 64>; + using CLayout = GMMA::CLayout_64x192; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_208,_64>; + using ThrID = Layout<_128>; + using 
ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<208, 64>; + using CLayout = GMMA::CLayout_64x208; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_208,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<208, 64>; + using CLayout = GMMA::CLayout_64x208; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_208,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<208, 64>; + using CLayout = GMMA::CLayout_64x208; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_208,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<208, 64>; + using CLayout = GMMA::CLayout_64x208; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_224,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<224, 64>; + using CLayout = GMMA::CLayout_64x224; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using 
Shape_MNK = Shape<_64,_224,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<224, 64>; + using CLayout = GMMA::CLayout_64x224; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_224,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<224, 64>; + using CLayout = GMMA::CLayout_64x224; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_224,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<224, 64>; + using CLayout = GMMA::CLayout_64x224; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_240,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<240, 64>; + using CLayout = GMMA::CLayout_64x240; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_240,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<240, 64>; + using CLayout = GMMA::CLayout_64x240; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using 
ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_240,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<240, 64>; + using CLayout = GMMA::CLayout_64x240; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_240,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<240, 64>; + using CLayout = GMMA::CLayout_64x240; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_256,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<256, 64>; + using CLayout = GMMA::CLayout_64x256; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_256,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<256, 64>; + using CLayout = GMMA::CLayout_64x256; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_256,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<256, 64>; + using CLayout = GMMA::CLayout_64x256; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_256,_64>; + using ThrID = 
Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<256, 64>; + using CLayout = GMMA::CLayout_64x256; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_8,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 8, 64>; + using CLayout = GMMA::CLayout_64x8; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_8,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 8, 64>; + using CLayout = GMMA::CLayout_64x8; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_8,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 8, 64>; + using CLayout = GMMA::CLayout_64x8; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_8,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 8, 64>; + using CLayout = GMMA::CLayout_64x8; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_16,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 16, 64>; + using CLayout = GMMA::CLayout_64x16; + + GMMA::ScaleOut accumulate_ = 
GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_16,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 16, 64>; + using CLayout = GMMA::CLayout_64x16; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_16,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 16, 64>; + using CLayout = GMMA::CLayout_64x16; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_16,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 16, 64>; + using CLayout = GMMA::CLayout_64x16; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_32,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 32, 64>; + using CLayout = GMMA::CLayout_64x32; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_32,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 32, 64>; + using CLayout = GMMA::CLayout_64x32; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + 
using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_32,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 32, 64>; + using CLayout = GMMA::CLayout_64x32; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_32,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 32, 64>; + using CLayout = GMMA::CLayout_64x32; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_48,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 48, 64>; + using CLayout = GMMA::CLayout_64x48; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_48,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 48, 64>; + using CLayout = GMMA::CLayout_64x48; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_48,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 48, 64>; + using CLayout = GMMA::CLayout_64x48; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = 
sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_48,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 48, 64>; + using CLayout = GMMA::CLayout_64x48; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_64,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 64, 64>; + using CLayout = GMMA::CLayout_64x64; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_64,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 64, 64>; + using CLayout = GMMA::CLayout_64x64; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_64,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 64, 64>; + using CLayout = GMMA::CLayout_64x64; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_64,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 64, 64>; + using CLayout = GMMA::CLayout_64x64; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using 
FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_80,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 80, 64>; + using CLayout = GMMA::CLayout_64x80; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_80,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 80, 64>; + using CLayout = GMMA::CLayout_64x80; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_80,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 80, 64>; + using CLayout = GMMA::CLayout_64x80; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_80,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 80, 64>; + using CLayout = GMMA::CLayout_64x80; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_96,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 96, 64>; + using CLayout = GMMA::CLayout_64x96; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = 
Shape<_64,_96,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 96, 64>; + using CLayout = GMMA::CLayout_64x96; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_96,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 96, 64>; + using CLayout = GMMA::CLayout_64x96; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_96,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 96, 64>; + using CLayout = GMMA::CLayout_64x96; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_112,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<112, 64>; + using CLayout = GMMA::CLayout_64x112; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_112,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<112, 64>; + using CLayout = GMMA::CLayout_64x112; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_112,_64>; + using 
ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<112, 64>; + using CLayout = GMMA::CLayout_64x112; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_112,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<112, 64>; + using CLayout = GMMA::CLayout_64x112; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_128,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<128, 64>; + using CLayout = GMMA::CLayout_64x128; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_128,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<128, 64>; + using CLayout = GMMA::CLayout_64x128; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_128,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<128, 64>; + using CLayout = GMMA::CLayout_64x128; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_128,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<128, 64>; + using CLayout = 
GMMA::CLayout_64x128; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_144,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<144, 64>; + using CLayout = GMMA::CLayout_64x144; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_144,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<144, 64>; + using CLayout = GMMA::CLayout_64x144; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_144,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<144, 64>; + using CLayout = GMMA::CLayout_64x144; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_144,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<144, 64>; + using CLayout = GMMA::CLayout_64x144; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_160,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + 
using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<160, 64>; + using CLayout = GMMA::CLayout_64x160; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_160,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<160, 64>; + using CLayout = GMMA::CLayout_64x160; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_160,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<160, 64>; + using CLayout = GMMA::CLayout_64x160; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_160,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<160, 64>; + using CLayout = GMMA::CLayout_64x160; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_176,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<176, 64>; + using CLayout = GMMA::CLayout_64x176; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_176,_64>; + 
using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<176, 64>; + using CLayout = GMMA::CLayout_64x176; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_176,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<176, 64>; + using CLayout = GMMA::CLayout_64x176; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_176,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<176, 64>; + using CLayout = GMMA::CLayout_64x176; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_192,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<192, 64>; + using CLayout = GMMA::CLayout_64x192; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_192,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<192, 64>; + using CLayout = GMMA::CLayout_64x192; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_192,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; 
+ using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<192, 64>; + using CLayout = GMMA::CLayout_64x192; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_192,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<192, 64>; + using CLayout = GMMA::CLayout_64x192; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_208,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<208, 64>; + using CLayout = GMMA::CLayout_64x208; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_208,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<208, 64>; + using CLayout = GMMA::CLayout_64x208; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_208,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<208, 64>; + using CLayout = GMMA::CLayout_64x208; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_208,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + 
using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<208, 64>; + using CLayout = GMMA::CLayout_64x208; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_224,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<224, 64>; + using CLayout = GMMA::CLayout_64x224; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_224,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<224, 64>; + using CLayout = GMMA::CLayout_64x224; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_224,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<224, 64>; + using CLayout = GMMA::CLayout_64x224; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_224,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<224, 64>; + using CLayout = GMMA::CLayout_64x224; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using 
Shape_MNK = Shape<_64,_240,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<240, 64>; + using CLayout = GMMA::CLayout_64x240; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_240,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<240, 64>; + using CLayout = GMMA::CLayout_64x240; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_240,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<240, 64>; + using CLayout = GMMA::CLayout_64x240; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_240,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<240, 64>; + using CLayout = GMMA::CLayout_64x240; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_256,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<256, 64>; + using CLayout = GMMA::CLayout_64x256; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_256,_64>; + using 
ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<256, 64>; + using CLayout = GMMA::CLayout_64x256; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_256,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<256, 64>; + using CLayout = GMMA::CLayout_64x256; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_256,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<256, 64>; + using CLayout = GMMA::CLayout_64x256; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_8,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 8, 64>; + using CLayout = GMMA::CLayout_64x8; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_8,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 8, 64>; + using CLayout = GMMA::CLayout_64x8; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_8,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 8, 64>; + using CLayout = GMMA::CLayout_64x8; + + GMMA::ScaleOut 
accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_8,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 8, 64>; + using CLayout = GMMA::CLayout_64x8; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_16,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 16, 64>; + using CLayout = GMMA::CLayout_64x16; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_16,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 16, 64>; + using CLayout = GMMA::CLayout_64x16; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_16,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 16, 64>; + using CLayout = GMMA::CLayout_64x16; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_16,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 16, 64>; + using CLayout = GMMA::CLayout_64x16; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, 
float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_32,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 32, 64>; + using CLayout = GMMA::CLayout_64x32; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_32,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 32, 64>; + using CLayout = GMMA::CLayout_64x32; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_32,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 32, 64>; + using CLayout = GMMA::CLayout_64x32; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_32,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 32, 64>; + using CLayout = GMMA::CLayout_64x32; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_48,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 48, 64>; + using CLayout = GMMA::CLayout_64x48; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; 
+ using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_48,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 48, 64>; + using CLayout = GMMA::CLayout_64x48; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_48,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 48, 64>; + using CLayout = GMMA::CLayout_64x48; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_48,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 48, 64>; + using CLayout = GMMA::CLayout_64x48; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_64,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 64, 64>; + using CLayout = GMMA::CLayout_64x64; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_64,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 64, 64>; + using CLayout = GMMA::CLayout_64x64; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using 
Shape_MNK = Shape<_64,_64,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 64, 64>; + using CLayout = GMMA::CLayout_64x64; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_64,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 64, 64>; + using CLayout = GMMA::CLayout_64x64; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_80,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 80, 64>; + using CLayout = GMMA::CLayout_64x80; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_80,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 80, 64>; + using CLayout = GMMA::CLayout_64x80; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_80,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 80, 64>; + using CLayout = GMMA::CLayout_64x80; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK 
= Shape<_64,_80,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 80, 64>; + using CLayout = GMMA::CLayout_64x80; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_96,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 96, 64>; + using CLayout = GMMA::CLayout_64x96; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_96,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 96, 64>; + using CLayout = GMMA::CLayout_64x96; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_96,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 96, 64>; + using CLayout = GMMA::CLayout_64x96; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_96,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 96, 64>; + using CLayout = GMMA::CLayout_64x96; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_112,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using 
BLayout = GMMA::ABLayout<112, 64>; + using CLayout = GMMA::CLayout_64x112; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_112,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<112, 64>; + using CLayout = GMMA::CLayout_64x112; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_112,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<112, 64>; + using CLayout = GMMA::CLayout_64x112; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_112,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<112, 64>; + using CLayout = GMMA::CLayout_64x112; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_128,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<128, 64>; + using CLayout = GMMA::CLayout_64x128; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_128,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<128, 64>; + using 
CLayout = GMMA::CLayout_64x128; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_128,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<128, 64>; + using CLayout = GMMA::CLayout_64x128; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_128,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<128, 64>; + using CLayout = GMMA::CLayout_64x128; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_144,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<144, 64>; + using CLayout = GMMA::CLayout_64x144; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_144,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<144, 64>; + using CLayout = GMMA::CLayout_64x144; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_144,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<144, 64>; + using CLayout = 
GMMA::CLayout_64x144; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_144,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<144, 64>; + using CLayout = GMMA::CLayout_64x144; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_160,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<160, 64>; + using CLayout = GMMA::CLayout_64x160; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_160,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<160, 64>; + using CLayout = GMMA::CLayout_64x160; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_160,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<160, 64>; + using CLayout = GMMA::CLayout_64x160; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_160,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = 
GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<160, 64>; + using CLayout = GMMA::CLayout_64x160; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_176,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<176, 64>; + using CLayout = GMMA::CLayout_64x176; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_176,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<176, 64>; + using CLayout = GMMA::CLayout_64x176; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_176,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<176, 64>; + using CLayout = GMMA::CLayout_64x176; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_176,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<176, 64>; + using CLayout = GMMA::CLayout_64x176; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_192,_64>; + using ThrID = Layout<_128>; + 
using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<192, 64>; + using CLayout = GMMA::CLayout_64x192; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_192,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<192, 64>; + using CLayout = GMMA::CLayout_64x192; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_192,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<192, 64>; + using CLayout = GMMA::CLayout_64x192; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_192,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<192, 64>; + using CLayout = GMMA::CLayout_64x192; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_208,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<208, 64>; + using CLayout = GMMA::CLayout_64x208; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_208,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<208, 
64>; + using CLayout = GMMA::CLayout_64x208; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_208,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<208, 64>; + using CLayout = GMMA::CLayout_64x208; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_208,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<208, 64>; + using CLayout = GMMA::CLayout_64x208; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_224,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<224, 64>; + using CLayout = GMMA::CLayout_64x224; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_224,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<224, 64>; + using CLayout = GMMA::CLayout_64x224; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_224,_64>; + using ThrID = Layout<_128>; + using ALayout 
= GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<224, 64>; + using CLayout = GMMA::CLayout_64x224; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_224,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<224, 64>; + using CLayout = GMMA::CLayout_64x224; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_240,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<240, 64>; + using CLayout = GMMA::CLayout_64x240; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_240,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<240, 64>; + using CLayout = GMMA::CLayout_64x240; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_240,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<240, 64>; + using CLayout = GMMA::CLayout_64x240; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK 
= Shape<_64,_240,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<240, 64>; + using CLayout = GMMA::CLayout_64x240; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_256,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<256, 64>; + using CLayout = GMMA::CLayout_64x256; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_256,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<256, 64>; + using CLayout = GMMA::CLayout_64x256; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_256,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<256, 64>; + using CLayout = GMMA::CLayout_64x256; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_256,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<256, 64>; + using CLayout = GMMA::CLayout_64x256; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_8,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 8, 64>; + using 
CLayout = GMMA::CLayout_64x8; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_8,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 8, 64>; + using CLayout = GMMA::CLayout_64x8; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_8,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 8, 64>; + using CLayout = GMMA::CLayout_64x8; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_8,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 8, 64>; + using CLayout = GMMA::CLayout_64x8; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_16,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 16, 64>; + using CLayout = GMMA::CLayout_64x16; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_16,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 16, 64>; + using CLayout = GMMA::CLayout_64x16; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; 
+ using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_16,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 16, 64>; + using CLayout = GMMA::CLayout_64x16; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_16,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 16, 64>; + using CLayout = GMMA::CLayout_64x16; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_32,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 32, 64>; + using CLayout = GMMA::CLayout_64x32; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_32,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 32, 64>; + using CLayout = GMMA::CLayout_64x32; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_32,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 32, 64>; + using CLayout = GMMA::CLayout_64x32; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using 
Shape_MNK = Shape<_64,_32,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 32, 64>; + using CLayout = GMMA::CLayout_64x32; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_48,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 48, 64>; + using CLayout = GMMA::CLayout_64x48; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_48,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 48, 64>; + using CLayout = GMMA::CLayout_64x48; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_48,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 48, 64>; + using CLayout = GMMA::CLayout_64x48; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_48,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 48, 64>; + using CLayout = GMMA::CLayout_64x48; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using 
FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_64,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 64, 64>; + using CLayout = GMMA::CLayout_64x64; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_64,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 64, 64>; + using CLayout = GMMA::CLayout_64x64; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_64,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 64, 64>; + using CLayout = GMMA::CLayout_64x64; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_64,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 64, 64>; + using CLayout = GMMA::CLayout_64x64; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_80,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 80, 64>; + using CLayout = GMMA::CLayout_64x80; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_80,_64>; + using ThrID = Layout<_128>; + using ALayout = 
GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 80, 64>; + using CLayout = GMMA::CLayout_64x80; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_80,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 80, 64>; + using CLayout = GMMA::CLayout_64x80; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_80,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 80, 64>; + using CLayout = GMMA::CLayout_64x80; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_96,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 96, 64>; + using CLayout = GMMA::CLayout_64x96; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_96,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 96, 64>; + using CLayout = GMMA::CLayout_64x96; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_96,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = 
GMMA::ABLayout< 96, 64>; + using CLayout = GMMA::CLayout_64x96; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_96,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 96, 64>; + using CLayout = GMMA::CLayout_64x96; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_112,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<112, 64>; + using CLayout = GMMA::CLayout_64x112; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_112,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<112, 64>; + using CLayout = GMMA::CLayout_64x112; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_112,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<112, 64>; + using CLayout = GMMA::CLayout_64x112; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_112,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = 
GMMA::ABLayout<112, 64>; + using CLayout = GMMA::CLayout_64x112; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_128,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<128, 64>; + using CLayout = GMMA::CLayout_64x128; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_128,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<128, 64>; + using CLayout = GMMA::CLayout_64x128; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_128,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<128, 64>; + using CLayout = GMMA::CLayout_64x128; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_128,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<128, 64>; + using CLayout = GMMA::CLayout_64x128; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_144,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<144, 64>; + using CLayout = GMMA::CLayout_64x144; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + 
+//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_144,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<144, 64>; + using CLayout = GMMA::CLayout_64x144; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_144,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<144, 64>; + using CLayout = GMMA::CLayout_64x144; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_144,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<144, 64>; + using CLayout = GMMA::CLayout_64x144; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_160,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<160, 64>; + using CLayout = GMMA::CLayout_64x160; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_160,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<160, 64>; + using CLayout = GMMA::CLayout_64x160; + 
+ GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_160,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<160, 64>; + using CLayout = GMMA::CLayout_64x160; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_160,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<160, 64>; + using CLayout = GMMA::CLayout_64x160; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_176,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<176, 64>; + using CLayout = GMMA::CLayout_64x176; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_176,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<176, 64>; + using CLayout = GMMA::CLayout_64x176; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_176,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = 
GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<176, 64>; + using CLayout = GMMA::CLayout_64x176; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_176,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<176, 64>; + using CLayout = GMMA::CLayout_64x176; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_192,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<192, 64>; + using CLayout = GMMA::CLayout_64x192; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_192,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<192, 64>; + using CLayout = GMMA::CLayout_64x192; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_192,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<192, 64>; + using CLayout = GMMA::CLayout_64x192; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_192,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<192, 64>; + using CLayout = GMMA::CLayout_64x192; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + 
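
Each of these sparse-GMMA MMA_Traits specializations exposes the same set of aliases (ValType*, FrgType*, Shape_MNK, ThrID, and the A/E/B/C layouts) that the MMA atom machinery reads when staging operands. As a rough, hypothetical sketch of how such a trait is consumed (the op tag name below is a placeholder, not an identifier defined in this patch), one would build a tiled MMA from the corresponding atom roughly like this:

    // Hypothetical sketch; "SM90_SPARSE_OP" stands in for one of the sparse
    // GMMA op tags whose MMA_Traits are specialized above.
    using Atom     = cute::MMA_Atom<SM90_SPARSE_OP>;   // pulls in MMA_Traits<SM90_SPARSE_OP>
    auto tiled_mma = cute::make_tiled_mma(Atom{});     // one warpgroup: ThrID = Layout<_128>
    // Shape_MNK and the A/E/B/C layouts above then drive how A, the sparsity
    // metadata E, B, and the accumulator C are partitioned across threads.
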
+//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_208,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<208, 64>; + using CLayout = GMMA::CLayout_64x208; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_208,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<208, 64>; + using CLayout = GMMA::CLayout_64x208; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_208,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<208, 64>; + using CLayout = GMMA::CLayout_64x208; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_208,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<208, 64>; + using CLayout = GMMA::CLayout_64x208; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_224,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<224, 64>; + 
using CLayout = GMMA::CLayout_64x224; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_224,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<224, 64>; + using CLayout = GMMA::CLayout_64x224; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_224,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<224, 64>; + using CLayout = GMMA::CLayout_64x224; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_224,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<224, 64>; + using CLayout = GMMA::CLayout_64x224; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_240,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<240, 64>; + using CLayout = GMMA::CLayout_64x240; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_240,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = 
GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<240, 64>; + using CLayout = GMMA::CLayout_64x240; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_240,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<240, 64>; + using CLayout = GMMA::CLayout_64x240; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_240,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<240, 64>; + using CLayout = GMMA::CLayout_64x240; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_256,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<256, 64>; + using CLayout = GMMA::CLayout_64x256; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_256,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<256, 64>; + using CLayout = GMMA::CLayout_64x256; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_256,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<256, 64>; + 
using CLayout = GMMA::CLayout_64x256; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_256,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<256, 64>; + using CLayout = GMMA::CLayout_64x256; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +} // end namespace cute diff --git a/include/cute/config.hpp b/include/cute/config.hpp index 35d4f8fdf0..b5cfcf47d3 100644 --- a/include/cute/config.hpp +++ b/include/cute/config.hpp @@ -142,21 +142,8 @@ # include #endif -// -// Support -// - -#include - -// -// Basic types -// - -#include - // // Debugging utilities // -#include #include diff --git a/include/cute/container/alignment.hpp b/include/cute/container/alignment.hpp index 4cf60d899f..52e4cbadd9 100644 --- a/include/cute/container/alignment.hpp +++ b/include/cute/container/alignment.hpp @@ -54,17 +54,17 @@ is_byte_aligned(void const* const ptr) # define CUTE_ALIGNAS(n) alignas(n) #endif -template +template struct aligned_struct {}; -template <> struct CUTE_ALIGNAS( 1) aligned_struct< 1> {}; -template <> struct CUTE_ALIGNAS( 2) aligned_struct< 2> {}; -template <> struct CUTE_ALIGNAS( 4) aligned_struct< 4> {}; -template <> struct CUTE_ALIGNAS( 8) aligned_struct< 8> {}; -template <> struct CUTE_ALIGNAS( 16) aligned_struct< 16> {}; -template <> struct CUTE_ALIGNAS( 32) aligned_struct< 32> {}; -template <> struct CUTE_ALIGNAS( 64) aligned_struct< 64> {}; -template <> struct CUTE_ALIGNAS(128) aligned_struct<128> {}; -template <> struct CUTE_ALIGNAS(256) aligned_struct<256> {}; +template struct CUTE_ALIGNAS( 1) aligned_struct< 1, Child> {}; +template struct CUTE_ALIGNAS( 2) aligned_struct< 2, Child> {}; +template struct CUTE_ALIGNAS( 4) aligned_struct< 4, Child> {}; +template struct CUTE_ALIGNAS( 8) aligned_struct< 8, Child> {}; +template struct CUTE_ALIGNAS( 16) aligned_struct< 16, Child> {}; +template struct CUTE_ALIGNAS( 32) aligned_struct< 32, Child> {}; +template struct CUTE_ALIGNAS( 64) aligned_struct< 64, Child> {}; +template struct CUTE_ALIGNAS(128) aligned_struct<128, Child> {}; +template struct CUTE_ALIGNAS(256) aligned_struct<256, Child> {}; } // end namespace cute diff --git a/include/cute/container/array_aligned.hpp b/include/cute/container/array_aligned.hpp index 9895a8da77..a9d14a1a25 100644 --- a/include/cute/container/array_aligned.hpp +++ b/include/cute/container/array_aligned.hpp @@ -30,8 +30,8 @@ **************************************************************************************************/ #pragma once -#include -#include +#include // CUTE_ALIGNAS +#include // cute::array namespace cute { diff --git a/include/cute/container/array_subbyte.hpp b/include/cute/container/array_subbyte.hpp index 1963d8ce7b..6aa26bc9f0 100644 --- a/include/cute/container/array_subbyte.hpp +++ b/include/cute/container/array_subbyte.hpp @@ -181,6 +181,20 @@ struct subbyte_reference } }; +template +CUTE_HOST_DEVICE +void +print(subbyte_reference ref) { + cute::print(ref.get()); +} + +template 
+CUTE_HOST_DEVICE +void +pretty_print(subbyte_reference ref) { + cute::pretty_print(ref.get()); +} + // // subbyte_iterator // Random-access iterator over subbyte references diff --git a/include/cute/container/bit_field.hpp b/include/cute/container/bit_field.hpp index c5748d84c3..d7fac42a54 100644 --- a/include/cute/container/bit_field.hpp +++ b/include/cute/container/bit_field.hpp @@ -35,9 +35,9 @@ #pragma once -#include - +#include // CUTE_HOST_DEVICE #include // uint_bit_t +#include // cute::is_same namespace cute { diff --git a/include/cute/container/cuda_types.hpp b/include/cute/container/cuda_types.hpp index 8034cb271d..fbc314e543 100644 --- a/include/cute/container/cuda_types.hpp +++ b/include/cute/container/cuda_types.hpp @@ -30,12 +30,8 @@ **************************************************************************************************/ #pragma once -#include - -#include - -#include -#include +#include // CUTE_HOST_DEVICE, CUTE_GCC_UNREACHABLE +#include // cute::integral_constant namespace cute { diff --git a/include/cute/container/tuple.hpp b/include/cute/container/tuple.hpp index 54d282419e..3123a68d83 100644 --- a/include/cute/container/tuple.hpp +++ b/include/cute/container/tuple.hpp @@ -634,14 +634,23 @@ template CUTE_HOST_DEVICE void print_tuple(Tuple const& t, index_sequence, char s = '(', char e = ')') { using cute::print; - print(s); ((void(print(Is == 0 ? '\0' : ',')), void(print(get(t)))), ...); print(e); + if (sizeof...(Is) == 0) { + print(s); + } else { + ((void(print(Is == 0 ? s : ',')), void(print(get(t)))), ...); + } + print(e); } #if !defined(__CUDACC_RTC__) template CUTE_HOST std::ostream& print_tuple_os(std::ostream& os, Tuple const& t, index_sequence, char s = '(', char e = ')') { - os << s; (void(os << (Is == 0 ? '\0' : ',') << get(t)), ...); + if (sizeof...(Is) == 0) { + os << s; + } else { + (void(os << (Is == 0 ? s : ',') << get(t)), ...); + } return os << e; } #endif // !defined(__CUDACC_RTC__) diff --git a/include/cute/container/type_list.hpp b/include/cute/container/type_list.hpp index 2db934356b..a15f2c1c15 100644 --- a/include/cute/container/type_list.hpp +++ b/include/cute/container/type_list.hpp @@ -30,8 +30,7 @@ **************************************************************************************************/ #pragma once -#include -#include +#include // CUTE_HOST_DEVICE, CUTE_STL_NAMESPACE namespace cute { diff --git a/include/cute/int_tuple.hpp b/include/cute/int_tuple.hpp index ceafba0d80..95d06bbdd7 100644 --- a/include/cute/int_tuple.hpp +++ b/include/cute/int_tuple.hpp @@ -30,12 +30,11 @@ **************************************************************************************************/ #pragma once -#include - -#include -#include -#include -#include +#include // CUTE_HOST_DEVICE +#include // cute::array +#include // cute::is_tuple +#include // cute::Int +#include // cute::transform /** IntTuple is an integer or a tuple of IntTuples. 
* This file holds utilities for working with IntTuples, @@ -92,7 +91,7 @@ template using rank_t = decltype(rank(declval())); template -static constexpr int rank_v = rank_t::value; +static constexpr auto rank_v = rank_t::value; // // shape @@ -212,7 +211,7 @@ template using depth_t = decltype(depth(declval())); template -static constexpr int depth_v = depth_t::value; +static constexpr auto depth_v = depth_t::value; // // product @@ -276,7 +275,7 @@ size(IntTuple const& a) } template -static constexpr int size_v = decltype(size(declval()))::value; +static constexpr auto size_v = decltype(size(declval()))::value; // // sum @@ -522,68 +521,31 @@ compatible(IntTupleA const& a, IntTupleB const& b) template using is_compatible = decltype(compatible(declval(), declval())); -/** Test if Shape A is weakly compatible with Shape B: - * there exists a Shape C congruent to A such that compatible(elem_scale(A,C), B) - * Equivalently, the size of Shape B is a multiple of Shape A at each terminal of Shape A. - * weakly_compatible is a partial order on A and B: A <= B +/** Test if Shape A is evenly divided by Tiler B + * @returns Static or dynamic boolean + * @post if result is true_type, then + * size(a) == logical_divide(make_layout(shape(a)),b) will always compile + * and result in true_type. */ -template +template CUTE_HOST_DEVICE constexpr auto -weakly_compatible(IntTupleA const& a, IntTupleB const& b) +evenly_divides(Shape const& a, Tiler const& b) { - if constexpr (is_tuple::value && is_tuple::value) { - if constexpr (tuple_size::value != tuple_size::value) { + if constexpr (is_tuple::value) { + if constexpr (rank_v > rank_v) { return false_type{}; } else { - return transform_apply(a, b, [](auto const& x, auto const& y) { return weakly_compatible(x,y); }, + return transform_apply(b, a, [](auto const& x, auto const& y) { return evenly_divides(y,x); }, [](auto const&... z) { return (true_type{} && ... && z); }); } - } else if constexpr (is_integral::value) { - return size(b) % a == Int<0>{}; - } else if constexpr (is_integral::value) { - return false_type{}; } else { - return weakly_compatible(shape(a), shape(b)); + return size(a) == size(b) * size(ceil_div(shape(a), b)); } CUTE_GCC_UNREACHABLE; } -template -using is_weakly_compatible = decltype(weakly_compatible(declval(), declval())); - -/** Test if Shape A is softly compatible with Shape B: - * there exists a Shape C congruent to A such that compatible(shape_div(A,C), B) - * Equivalently, the size of Shape B divides Shape A at each terminal of Shape A. - * softly_compatible is a partial order on A and B: A <= B - */ -template -CUTE_HOST_DEVICE constexpr -auto -softly_compatible(IntTupleA const& a, IntTupleB const& b) -{ - if constexpr (is_tuple::value && is_tuple::value) { - if constexpr (tuple_size::value != tuple_size::value) { - return false_type{}; - } else { - return transform_apply(a, b, [](auto const& x, auto const& y) { return softly_compatible(x,y); }, - [](auto const&... z) { return (true_type{} && ... 
&& z); }); - } - } else if constexpr (is_integral::value) { - return a % size(b) == Int<0>{}; - } else if constexpr (is_integral::value) { - return false_type{}; - } else { - return softly_compatible(shape(a), shape(b)); - } - - CUTE_GCC_UNREACHABLE; -} - -template -using is_softly_compatible = decltype(softly_compatible(declval(), declval())); - /** Replace the elements of Tuple B that are paired with an Int<0> with an Int<1> */ template @@ -594,7 +556,7 @@ filter_zeros(IntTupleA const& a, IntTupleB const& b) if constexpr (is_tuple::value) { return transform(a, b, [](auto const& x, auto const& y) { return filter_zeros(x,y); }); } else if constexpr (is_constant<0, IntTupleA>::value) { - return Int<1>{}; + return repeat_like(b, Int<1>{}); } else { return b; } @@ -899,92 +861,4 @@ elem_geq(T const& t, U const& u) { return !elem_less(t, u); } -namespace detail { - -/** Increment a (dynamic) coord lexicographically within a shape - * @pre is_congruent::value - * \code - * auto shape = make_shape(1,2,make_shape(2,3),3); - * - * int i = 0; - * for (auto coord = repeat_like(shape, 0); back(coord) != back(shape); increment(coord, shape)) { - * std::cout << i++ << ": " << coord << std::endl; - * } - * assert(i == size(shape)); - * \endcode - */ -template -CUTE_HOST_DEVICE constexpr -void -increment(Coord& coord, Shape const& shape) -{ - if constexpr (is_integral::value) { - ++coord; - } else { - increment(get(coord), get(shape)); - if constexpr (I+1 < tuple_size::value) { - if (back(get(coord)) == back(get(shape))) { - back(get(coord)) = 0; - increment(coord, shape); - } - } - } -} - -} // end namespace detail - -struct ForwardCoordIteratorSentinal -{}; - -// A forward iterator for a starting coordinate in a shape's domain, and a shape. -// The starting coordinate may be zero but need not necessarily be. 
-template -struct ForwardCoordIterator -{ - static_assert(is_congruent::value); - - CUTE_HOST_DEVICE constexpr - Coord const& operator*() const { return coord; } - - CUTE_HOST_DEVICE constexpr - ForwardCoordIterator& operator++() { detail::increment(coord, shape); return *this; } - - // Sentinel for the end of the implied range - CUTE_HOST_DEVICE constexpr - bool operator< (ForwardCoordIteratorSentinal const&) const { return back(coord) < back(shape); } - CUTE_HOST_DEVICE constexpr - bool operator==(ForwardCoordIteratorSentinal const&) const { return back(coord) == back(shape); } - CUTE_HOST_DEVICE constexpr - bool operator!=(ForwardCoordIteratorSentinal const&) const { return back(coord) != back(shape); } - // NOTE: These are expensive, avoid use - CUTE_HOST_DEVICE constexpr - bool operator< (ForwardCoordIterator const& other) const { return colex_less(coord, other.coord); } - CUTE_HOST_DEVICE constexpr - bool operator==(ForwardCoordIterator const& other) const { return coord == other.coord; } - CUTE_HOST_DEVICE constexpr - bool operator!=(ForwardCoordIterator const& other) const { return coord != other.coord; } - - Coord coord; - Shape const& shape; -}; - -// A forward iterator for a coordinate that starts from a provided coordinate -template -CUTE_HOST_DEVICE constexpr -auto -make_coord_iterator(Coord const& coord, Shape const& shape) -{ - return ForwardCoordIterator{coord,shape}; -} - -// A forward iterator for a coordinate that starts from zero -template -CUTE_HOST_DEVICE constexpr -auto -make_coord_iterator(Shape const& shape) -{ - auto coord = repeat_like(shape, int(0)); - return make_coord_iterator(coord, shape); -} - } // end namespace cute diff --git a/include/cute/layout.hpp b/include/cute/layout.hpp index 60581192b0..bc1b54efbc 100644 --- a/include/cute/layout.hpp +++ b/include/cute/layout.hpp @@ -31,13 +31,13 @@ #pragma once #include - -#include #include #include +#include #include -#include #include +#include +#include // cute::sizeof_bits namespace cute { @@ -660,7 +660,7 @@ template using cosize_t = decltype(cosize(declval())); template -static constexpr int cosize_v = cosize_t::value; +static constexpr auto cosize_v = cosize_t::value; // With crd2idx(coord, shape), makes sense to have crd2idx(coord, Layout) as well template @@ -905,6 +905,15 @@ filter_zeros(Layout const& layout) return make_layout(filter_zeros(layout.stride(), layout.shape()), layout.stride()); } +// Replace the modes in layout that correspond to a 0 at the terminals of trg_profile with a 1-size +template +CUTE_HOST_DEVICE constexpr +auto +filter_zeros(Layout const& layout, IntTuple const& trg_profile) +{ + return make_layout(filter_zeros(trg_profile, layout.shape()), layout.stride()); +} + // Remove all of the 0-strides and 1-sizes // Return 1-shape if empty template @@ -1350,7 +1359,8 @@ max_common_vector(Layout const& a, /* Return a layout that distributes ShapeB over ShapeA. * * @returns Layout result - * @post softly_compatible(@a b, @a result) + * @post evenly_divides(@a b, size(@a result)) + * @post evenly_divides(@a a, @a result) * @post For all i,j in [0,size(@a result)) with i < j, @a result(i) < @a result(j). Surjective and Ordered. 
* @post composition(make_layout(shape(@a a)), @a result) is admissible * \code @@ -1726,8 +1736,8 @@ tile_to_shape(Layout const& block, // Assert proper division if constexpr (is_static::value) { - CUTE_STATIC_ASSERT_V(weakly_compatible(block_shape, target_shape), - "tile_to_shape: block shape does not divide the target shape."); + CUTE_STATIC_ASSERT_V(evenly_divides(target_shape, block_shape), + "tile_to_shape: block shape does not divide the target shape."); } auto product_shape = ceil_div(target_shape, block_shape); @@ -1924,92 +1934,97 @@ print_layout(Layout const& layout, ThrID const& thrid) // (m,n) -> (tid,vid) a printf("+\n"); } -// Generic 2D Layout to Latex printer -- B&W 8-value color coding -template +struct TikzColor_White { + CUTE_HOST_DEVICE char const* + operator()(int idx) const { + return "white"; + } +}; + +struct TikzColor_BWx8 { + CUTE_HOST_DEVICE char const* + operator()(int idx) const { + static char const* color_map[8] = {"black!00", "black!40", "black!20", "black!60", + "black!10", "black!50", "black!30", "black!70"}; + return color_map[idx % 8]; + } +}; + +struct TikzColor_TV { + CUTE_HOST_DEVICE char const* + operator()(int tid, int vid) const { + static char const* color_map[8] = {"{rgb,255:red,175;green,175;blue,255}", + "{rgb,255:red,175;green,255;blue,175}", + "{rgb,255:red,255;green,255;blue,175}", + "{rgb,255:red,255;green,175;blue,175}", + "{rgb,255:red,210;green,210;blue,255}", + "{rgb,255:red,210;green,255;blue,210}", + "{rgb,255:red,255;green,255;blue,210}", + "{rgb,255:red,255;green,210;blue,210}"}; + return color_map[tid % 8]; + } +}; + +// Generic 2D Layout to LaTeX printer +template CUTE_HOST_DEVICE void -print_latex(LayoutA const& layout_a) +print_latex(LayoutA const& layout_a, // (m,n) -> idx + TikzColorFn color = {}) // lambda(idx) -> tikz color string { CUTE_STATIC_ASSERT_V(rank(layout_a) <= Int<2>{}); auto layout = append<2>(layout_a, Layout<_1,_0>{}); - char const* latex_header = - "\\documentclass[convert]{standalone}\n" - "\\usepackage{tikz}\n\n" - "\\begin{document}\n" - "\\begin{tikzpicture}[x={(0cm,-1cm)},y={(1cm,0cm)},box/.style={rectangle,draw=black,thick,minimum size=1cm,anchor=center,font=\\Large}]\n\n"; - char const* latex_footer = - "\\end{tikzpicture}\n" - "\\end{document}\n"; - - char const* color_map[8] = {"black!00", - "black!40", - "black!20", - "black!60", - "black!10", - "black!50", - "black!30", - "black!70"}; - - // Header + // Commented print(layout) printf("%% Layout: "); print(layout); printf("\n"); - - printf(latex_header); + // Header + printf("\\documentclass[convert]{standalone}\n" + "\\usepackage{tikz}\n\n" + "\\begin{document}\n" + "\\begin{tikzpicture}[x={(0cm,-1cm)},y={(1cm,0cm)},every node/.style={minimum size=1cm, outer sep=0pt}]\n\n"); // Layout for (int i = 0; i < size<0>(layout); ++i) { for (int j = 0; j < size<1>(layout); ++j) { int idx = layout(i,j); - printf("\\node[box,fill=%s] at (%d,%d) {%d};\n", - color_map[idx % 8], - i, j, - idx); + printf("\\node[fill=%s] at (%d,%d) {%d};\n", + color(idx), i, j, idx); } } - + // Grid + printf("\\draw[color=black,thick,shift={(-0.5,-0.5)}] (0,0) grid (%d,%d);\n\n", + int(size<0>(layout)), int(size<1>(layout))); // Labels - for (int i = 0, j = -1; i < size<0>(layout); ++i) { + for (int i = 0, j = -1; i < size<0>(layout); ++i) { printf("\\node at (%d,%d) {\\Large{\\texttt{%d}}};\n", i, j, i); } - for (int j = 0, i = -1; j < size<1>(layout); ++j) { + for (int i = -1, j = 0; j < size<1>(layout); ++j) { printf("\\node at (%d,%d) {\\Large{\\texttt{%d}}};\n", i, j, j); } 
// Footer - printf(latex_footer); + printf("\\end{tikzpicture}\n" + "\\end{document}\n"); } -// Generic ThrVal 2D Layout to Latex TIKZ -- 8-value color coded by thread -template +// Generic ThrVal 2D Layout to LaTeX TikZ +template CUTE_HOST_DEVICE void -print_latex(Layout const& layout, ThrID const& thr) // (m,n) -> (tid,vid) and tid -> thr_idx +print_latex(Layout const& layout, // (m,n) -> (tid,vid) + ThrID const& thr, // tid -> thr_idx + TikzColorFn color = {}) // lambda(thr_idx,val_idx) -> tikz color string { CUTE_STATIC_ASSERT_V(rank(layout) == Int<2>{}); - char const* latex_header = - "\\documentclass[convert]{standalone}\n" - "\\usepackage{tikz}\n\n" - "\\begin{document}\n" - "\\begin{tikzpicture}[x={(0cm,-1cm)},y={(1cm,0cm)},box/.style={rectangle,draw=black,thick,minimum size=1cm,anchor=center}]\n\n"; - char const* latex_footer = - "\\end{tikzpicture}\n" - "\\end{document}\n"; - - char const* color_map[8] = {"{rgb,255:red,175;green,175;blue,255}", - "{rgb,255:red,175;green,255;blue,175}", - "{rgb,255:red,255;green,255;blue,175}", - "{rgb,255:red,255;green,175;blue,175}", - "{rgb,255:red,210;green,210;blue,255}", - "{rgb,255:red,210;green,255;blue,210}", - "{rgb,255:red,255;green,255;blue,210}", - "{rgb,255:red,255;green,210;blue,210}"}; - + // Commented prints + printf("%% Layout: "); print(layout); printf("\n"); + printf("%% ThrID : "); print(thr); printf("\n"); // Header - printf("%% layout: "); print(layout); printf("\n"); - printf("%% thrid: "); print(thr); printf("\n\n"); - - printf(latex_header); + printf("\\documentclass[convert]{standalone}\n" + "\\usepackage{tikz}\n\n" + "\\begin{document}\n" + "\\begin{tikzpicture}[x={(0cm,-1cm)},y={(1cm,0cm)},every node/.style={minimum size=1cm, outer sep=0pt}]\n\n"); // Layout for (int i = 0; i < size<0>(layout); ++i) { @@ -2018,13 +2033,15 @@ print_latex(Layout const& layout, ThrID const& thr) // (m,n) -> (tid,vid) and int val_idx = layout(i,j) / size(thr); int thr_idx = thr(thrid); - printf("\\node[box,fill=%s] at (%d,%d) {\\shortstack{T%d \\\\ V%d}};\n", - color_map[thr_idx % 8], + printf("\\node[fill=%s] at (%d,%d) {\\shortstack{T%d \\\\ V%d}};\n", + color(thr_idx, val_idx), i, j, thr_idx, val_idx); } } - + // Grid + printf("\\draw[color=black,thick,shift={(-0.5,-0.5)}] (0,0) grid (%d,%d);\n\n", + int(size<0>(layout)), int(size<1>(layout))); // Labels for (int i = 0, j = -1; i < size<0>(layout); ++i) { printf("\\node at (%d,%d) {\\Large{\\texttt{%d}}};\n", i, j, i); @@ -2034,13 +2051,8 @@ print_latex(Layout const& layout, ThrID const& thr) // (m,n) -> (tid,vid) and } // Footer - printf(latex_footer); + printf("\\end{tikzpicture}\n" + "\\end{document}\n"); } } // end namespace cute - -// -// Extended Layouts -// - -#include diff --git a/include/cute/layout_composed.hpp b/include/cute/layout_composed.hpp index fb62541cb4..3e5f836279 100644 --- a/include/cute/layout_composed.hpp +++ b/include/cute/layout_composed.hpp @@ -30,9 +30,9 @@ **************************************************************************************************/ #pragma once -#include - -#include +#include // CUTE_HOST_DEVICE, CUTE_GCC_UNREACHABLE +#include // cute::tuple +#include // cute::true_type, cute::false_type, cute::Int /* This implements a ComposedLayout of the form * LayoutA o Offset o LayoutB diff --git a/include/cute/numeric/arithmetic_tuple.hpp b/include/cute/numeric/arithmetic_tuple.hpp index 651ff8e887..2e46905719 100644 --- a/include/cute/numeric/arithmetic_tuple.hpp +++ b/include/cute/numeric/arithmetic_tuple.hpp @@ -197,7 +197,7 @@ struct 
ArithmeticTupleIterator ArithmeticTupleIterator(ArithTuple const& coord = {}) : coord_(coord) {} CUTE_HOST_DEVICE constexpr - ArithTuple const& operator*() const { return coord_; } + ArithTuple operator*() const { return coord_; } template CUTE_HOST_DEVICE constexpr @@ -206,7 +206,7 @@ struct ArithmeticTupleIterator template CUTE_HOST_DEVICE constexpr auto operator+(Coord const& c) const { - return ArithmeticTupleIterator(coord_ + c); + return ArithmeticTupleIterator>(coord_ + c); } }; @@ -268,13 +268,13 @@ basis_value(SB const& e) // Apply the N... pack to another Tuple template -CUTE_HOST_DEVICE constexpr auto -basis_get(SB const& e, Tuple const& t) +CUTE_HOST_DEVICE decltype(auto) +basis_get(SB const& e, Tuple&& t) { if constexpr (is_scaled_basis::value) { - return basis_get(e.value(), get(t)); + return basis_get(e.value(), get(static_cast(t))); } else { - return t; + return static_cast(t); } CUTE_GCC_UNREACHABLE; } diff --git a/include/cute/numeric/complex.hpp b/include/cute/numeric/complex.hpp index 5aa6664a89..7dd9ea5bf0 100644 --- a/include/cute/numeric/complex.hpp +++ b/include/cute/numeric/complex.hpp @@ -30,9 +30,9 @@ **************************************************************************************************/ #pragma once -#include -#include -#include +#include // CUTE_HOST_DEVICE + +#include // cutlass::complexm, cutlass::real, cutlass::imag, cutlass::is_complex namespace cute { diff --git a/include/cute/numeric/int.hpp b/include/cute/numeric/int.hpp index 169e3a0e67..571b3e3ed0 100644 --- a/include/cute/numeric/int.hpp +++ b/include/cute/numeric/int.hpp @@ -36,7 +36,9 @@ #include #endif -#include +#include // CUTE_STL_NAMESPACE + +#include // cutlass::int2b_t, cutlass::int4b_t namespace cute { @@ -53,8 +55,8 @@ using CUTE_STL_NAMESPACE::int32_t; using CUTE_STL_NAMESPACE::int64_t; template struct int_bit; -template <> struct int_bit< 2> { using type = cutlass::int2b_t; }; -template <> struct int_bit< 4> { using type = cutlass::int4b_t; }; +template <> struct int_bit< 2> { using type = int2_t; }; +template <> struct int_bit< 4> { using type = int4_t; }; template <> struct int_bit< 8> { using type = int8_t; }; template <> struct int_bit< 16> { using type = int16_t; }; template <> struct int_bit< 32> { using type = int32_t; }; @@ -83,9 +85,9 @@ using CUTE_STL_NAMESPACE::uint64_t; using cutlass::uint128_t; template struct uint_bit; -template <> struct uint_bit< 1> { using type = cutlass::uint1b_t; }; -template <> struct uint_bit< 2> { using type = cutlass::uint2b_t; }; -template <> struct uint_bit< 4> { using type = cutlass::uint4b_t; }; +template <> struct uint_bit< 1> { using type = uint1_t; }; +template <> struct uint_bit< 2> { using type = uint2_t; }; +template <> struct uint_bit< 4> { using type = uint4_t; }; template <> struct uint_bit< 8> { using type = uint8_t; }; template <> struct uint_bit< 16> { using type = uint16_t; }; template <> struct uint_bit< 32> { using type = uint32_t; }; diff --git a/include/cute/numeric/integral_constant.hpp b/include/cute/numeric/integral_constant.hpp index 46863ac286..e447103b99 100644 --- a/include/cute/numeric/integral_constant.hpp +++ b/include/cute/numeric/integral_constant.hpp @@ -30,10 +30,9 @@ **************************************************************************************************/ #pragma once -#include "cute/util/print.hpp" -#include "cute/util/type_traits.hpp" -#include "cute/numeric/math.hpp" -#include "cutlass/fast_math.h" +#include // cute::max, etc +#include // cute::print +#include // __CUTE_REQUIRES, 
cute::is_std_integral namespace cute { @@ -65,7 +64,7 @@ struct integral_constant : C { static constexpr T value = v; using value_type = T; // Disambiguate C::operator value_type() - //CUTE_HOST_DEVICE constexpr operator value_type() const noexcept { return value; } + //CUTE_HOST_DEVICE constexpr operator value_type() const noexcept { return value; } CUTE_HOST_DEVICE constexpr value_type operator()() const noexcept { return value; } }; @@ -406,6 +405,20 @@ conditional_return(false_type, TrueType&&, FalseType&& f) { return static_cast(f); } +template +CUTE_HOST_DEVICE constexpr +auto +conditional_return(bool b, C const&, C const&) { + return C{}; +} + +template +CUTE_HOST_DEVICE constexpr +auto +conditional_return(bool b, C const&, C const&) { + return b ? v : u; +} + // TrueType and FalseType must have a common type template CUTE_HOST_DEVICE constexpr @@ -435,7 +448,7 @@ static_value() return Int{}; } else { return Trait::value; - } + } CUTE_GCC_UNREACHABLE; } diff --git a/include/cute/numeric/integral_ratio.hpp b/include/cute/numeric/integral_ratio.hpp index 943b004982..1b1432533a 100644 --- a/include/cute/numeric/integral_ratio.hpp +++ b/include/cute/numeric/integral_ratio.hpp @@ -30,11 +30,10 @@ **************************************************************************************************/ #pragma once -#include - -#include -#include -#include +#include // CUTE_HOST_DEVICE +#include // cute::false_type, cute::true_type +#include // cute::signum +#include // __CUTE_REQUIRES namespace cute { diff --git a/include/cute/numeric/math.hpp b/include/cute/numeric/math.hpp index 6d95165de2..e493a3a953 100644 --- a/include/cute/numeric/math.hpp +++ b/include/cute/numeric/math.hpp @@ -30,9 +30,9 @@ **************************************************************************************************/ #pragma once -#include +#include // CUTE_HOST_DEVICE +#include // __CUTE_REQUIRES -#include #include namespace cute @@ -143,7 +143,7 @@ has_single_bit(T x) { // bit_width( 0b0111 ) = 3 template CUTE_HOST_DEVICE constexpr -T +int bit_width(T x) { static_assert(is_unsigned::value, "Only to be used for unsigned types."); constexpr int N = (numeric_limits::digits == 64 ? 6 : @@ -224,7 +224,7 @@ rotr(T x, int s) { // countl_zero( 0b00011100 ) = 3 template CUTE_HOST_DEVICE constexpr -T +int countl_zero(T x) { return numeric_limits::digits - bit_width(x); } @@ -235,7 +235,7 @@ countl_zero(T x) { // countl_one( 0b11100011 ) = 3 template CUTE_HOST_DEVICE constexpr -T +int countl_one(T x) { return countl_zero(~x); } @@ -246,7 +246,7 @@ countl_one(T x) { // countr_zero( 0b00011100 ) = 2 template CUTE_HOST_DEVICE constexpr -T +int countr_zero(T x) { return x == 0 ? numeric_limits::digits : bit_width(T(x & T(-x))) - 1; // bit_width of the LSB } @@ -257,7 +257,7 @@ countr_zero(T x) { // countr_one( 0b11100011 ) = 2 template CUTE_HOST_DEVICE constexpr -T +int countr_one(T x) { return countr_zero(~x); } @@ -285,7 +285,7 @@ popcount(T x) { // Computes the result of bitwise left-shift template CUTE_HOST_DEVICE constexpr -T +auto shiftl(T x, int s) { return s >= 0 ? (x << s) : (x >> -s); } @@ -293,7 +293,7 @@ shiftl(T x, int s) { // Computes the result of bitwise right-shift template CUTE_HOST_DEVICE constexpr -T +auto shiftr(T x, int s) { return s >= 0 ? 
(x >> s) : (x << -s); } diff --git a/include/cute/numeric/numeric_types.hpp b/include/cute/numeric/numeric_types.hpp index 02c700254b..07444331ff 100644 --- a/include/cute/numeric/numeric_types.hpp +++ b/include/cute/numeric/numeric_types.hpp @@ -30,12 +30,11 @@ **************************************************************************************************/ #pragma once -#include -#include -#include +#include // CUTE_HOST_DEVICE +#include // cute::int2_t, cute::int4_t, etc -#include -#include +#include // cutlass::sizeof_bits +#include // cutlass::float_e4m3_t, cutlass::float_e5m2_t, etc namespace cute { @@ -72,4 +71,65 @@ using cutlass::int4b_t; using cutlass::uint4b_t; using cutlass::bin1_t; -} // end namespace cute + +// +// Print utility +// + +CUTE_HOST_DEVICE +void +print(half_t a) { + printf("%f", static_cast(a)); +} + +CUTE_HOST_DEVICE +void +print(bfloat16_t a) { + printf("%f", static_cast(a)); +} + + +CUTE_HOST_DEVICE +void +print(tfloat32_t a) { + printf("%f", static_cast(a)); +} + +CUTE_HOST_DEVICE +void +print(float_e4m3_t a) { + printf("%f", static_cast(a)); +} + +CUTE_HOST_DEVICE +void +print(float_e5m2_t a) { + printf("%f", static_cast(a)); +} + +CUTE_HOST_DEVICE void +pretty_print(bfloat16_t v) { + printf("%*.2f", 8, float(v)); +} + +CUTE_HOST_DEVICE void +pretty_print(half_t v) { + printf("%*.2f", 8, float(v)); +} + +CUTE_HOST_DEVICE void +pretty_print(tfloat32_t v) { + printf("%*.2e", 10, static_cast(v)); +} + +CUTE_HOST_DEVICE void +pretty_print(float_e4m3_t t) { + printf("%*.2f", 8, static_cast(t)); +} + +CUTE_HOST_DEVICE void +pretty_print(float_e5m2_t t) { + printf("%*.2f", 8, static_cast(t)); +} + +} // namespace cute diff --git a/include/cute/numeric/real.hpp b/include/cute/numeric/real.hpp index f797bc13a1..4ce58dfa18 100644 --- a/include/cute/numeric/real.hpp +++ b/include/cute/numeric/real.hpp @@ -35,6 +35,24 @@ namespace cute { +/// Generic add +template +CUTE_HOST_DEVICE constexpr +void +add(C& c, A const& a, B const& b) +{ + c = a + b; +} + +/// Generic multiply +template +CUTE_HOST_DEVICE constexpr +void +mul(C& c, A const& a, B const& b) +{ + c = a * b; +} + /// Generic fused multiply-add template CUTE_HOST_DEVICE constexpr diff --git a/include/cute/pointer.hpp b/include/cute/pointer.hpp index 604477a0d3..4cfa129cce 100644 --- a/include/cute/pointer.hpp +++ b/include/cute/pointer.hpp @@ -30,17 +30,13 @@ **************************************************************************************************/ #pragma once -#include +#include // CUTE_HOST_DEVICE +#include // cute::iter_adaptor +#include +#include // cute::subbyte_iterator +#include // cute::true_type, cute::false_type +#include // sizeof_bits -#include -#include // sizeof_bits -#include -#include - -#include - -#include -#include namespace cute { @@ -50,6 +46,9 @@ namespace cute // Subbyte Types: uint2_t, uint4_t, etc // Requires construction of a subbyte_iterator in order to properly // resolve each element in byte-addressed memory. +// Sparse Types: sparse_elem +// A type that holds one physical element meant to represent S number of logical elements. +// Requires construction of a sparse_ptr that emulates access to the S logical elements. 
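
As context for the sparse path introduced in this hunk, here is a minimal, hedged sketch of what the comment above describes, assuming only what the diff itself shows: sparse_elem<S, T> stores one physical value standing in for S logical elements (its sparsity is exposed as a static member), and recast_ptr returns a sparse pointer for such types.

    // Hedged sketch based on this hunk; the sparse_ptr internals are defined
    // in the new pointer_sparse.hpp and are not reproduced here.
    using E   = cute::sparse_elem<2, cute::float_e5m2_t>; // 1 stored e5m2 value : 2 logical elems
    void* raw = nullptr;                                   // stand-in for byte-addressed storage
    auto  sp  = cute::recast_ptr<E>(raw);                  // takes the is_sparse branch below
    // Advancing sp walks logical elements; each group of E::sparsity logical
    // positions maps onto one stored physical value.
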
// template @@ -57,6 +56,11 @@ CUTE_HOST_DEVICE constexpr auto recast_ptr(void* ptr) { + if constexpr (is_sparse::value) { + constexpr int sparsity = NewT::sparsity; + NewT* p = reinterpret_cast(ptr); + return make_sparse_ptr(p); + } else if constexpr (cute::is_subbyte_v) { return subbyte_iterator(ptr); } else { @@ -70,6 +74,11 @@ CUTE_HOST_DEVICE constexpr auto recast_ptr(void const* ptr) { + if constexpr (is_sparse::value) { + constexpr int sparsity = NewT::sparsity; + NewT const* p = reinterpret_cast(ptr); + return make_sparse_ptr(p); + } else if constexpr (cute::is_subbyte_v) { return subbyte_iterator(ptr); } else { diff --git a/include/cute/pointer_base.hpp b/include/cute/pointer_base.hpp index db5d3dcfc4..90ca0ceb6e 100644 --- a/include/cute/pointer_base.hpp +++ b/include/cute/pointer_base.hpp @@ -30,10 +30,9 @@ **************************************************************************************************/ #pragma once -#include - -#include -#include // sizeof_bits +#include // CUTE_HOST_DEVICE +#include // cute::sizeof_bits +#include // cute::declval, cute::void_t, etc namespace cute { diff --git a/include/cute/pointer_flagged.hpp b/include/cute/pointer_flagged.hpp index 08751eb169..eb8d7e452e 100644 --- a/include/cute/pointer_flagged.hpp +++ b/include/cute/pointer_flagged.hpp @@ -30,15 +30,13 @@ **************************************************************************************************/ #pragma once -#include - -#include // cast_smem_ptr_to_uint - -#include -#include -#include - -#include +#include // CUTE_HOST_DEVICE +#include // cute::ComposedLayout +#include // cute::make_smem_ptr +#include // cute::is_sparse +#include // cute::make_swizzle_ptr +#include // cute::cast_smem_ptr_to_uint +#include // cute::Int namespace cute { @@ -124,6 +122,47 @@ as_position_independent_swizzle_tensor(Tensor&& tensor) CUTE_GCC_UNREACHABLE; } +// A model of a nullptr sparse_ptr> with B == sizeof_bits::value +// That represents an unset pointer. 
This is a placeholder type that is waiting for an smem_ptr +template +struct smem_sparse_ptr_flag_bits : Int<0> {}; + +template +using smem_sparse_ptr_flag = smem_sparse_ptr_flag_bits; + +// A flagged construction method to transform ComposedLayout +// Make a swizzle pointer tensor and check that the intended type size matches +template +CUTE_HOST_DEVICE constexpr +auto +make_tensor(Iterator const& ptr, + ComposedLayout,Layout> const& layout) +{ + static_assert(is_smem::value, "Expected smem."); + static_assert(is_sparse_ptr::value, "Expected sparse iter"); + static_assert(is_sparse>::value, "Expected sparse elem"); + static_assert(S == iter_value_t::sparsity, "Expected sparsity S"); + static_assert(B == sizeof_bits::raw_type>::value, "Expected B-bit pointer type"); + return make_tensor(make_swizzle_ptr(ptr, layout.layout_a()), layout.layout_b()); +} + +// NOTE: To preserve smem_ptr_flag_bits under recast ops +template +CUTE_HOST_DEVICE constexpr +auto +upcast(ComposedLayout,Layout> const& layout) +{ + static_assert(dependent_false, "Not implemented for safety"); +} + +template +CUTE_HOST_DEVICE constexpr +auto +downcast(ComposedLayout,Layout> const& layout) +{ + static_assert(dependent_false, "Not implemented for safety"); +} + // // Display utilities // @@ -151,4 +190,10 @@ CUTE_HOST_DEVICE void print(smem_ptr_flag_bits ptr) printf("smem_ptr[%db](unset)", B); } +template +CUTE_HOST_DEVICE void print(smem_sparse_ptr_flag_bits) +{ + printf("smem_sparse<%d>_ptr[%db](unset)", S, B); +} + } // end namespace cute diff --git a/include/cute/pointer_sparse.hpp b/include/cute/pointer_sparse.hpp new file mode 100644 index 0000000000..ccae458650 --- /dev/null +++ b/include/cute/pointer_sparse.hpp @@ -0,0 +1,172 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ * + **************************************************************************************************/ + +#pragma once + +#include // CUTE_HOST_DEVICE +#include // cute::iter_adaptor +#include // cute::false_type, cute::true_type +#include // cute::ratio + +namespace cute +{ + +// A data type that holds one physical element meant to represent Sparsity number of logical elements +// This class is purposely not compatible with anything -- know what you're doing if you attempt to use it +template +struct sparse_elem +{ + static constexpr int sparsity = Sparsity; + using raw_type = T; + T elem_; + + CUTE_HOST_DEVICE constexpr + explicit sparse_elem(T const& elem = {}) : elem_(elem) {} + + CUTE_HOST_DEVICE constexpr friend bool operator==(sparse_elem const& a, sparse_elem const& b) { return a.elem_ == b.elem_; } + CUTE_HOST_DEVICE constexpr friend bool operator!=(sparse_elem const& a, sparse_elem const& b) { return a.elem_ != b.elem_; } + CUTE_HOST_DEVICE constexpr friend bool operator< (sparse_elem const& a, sparse_elem const& b) { return a.elem_ < b.elem_; } + CUTE_HOST_DEVICE constexpr friend bool operator<=(sparse_elem const& a, sparse_elem const& b) { return a.elem_ <= b.elem_; } + CUTE_HOST_DEVICE constexpr friend bool operator> (sparse_elem const& a, sparse_elem const& b) { return a.elem_ > b.elem_; } + CUTE_HOST_DEVICE constexpr friend bool operator>=(sparse_elem const& a, sparse_elem const& b) { return a.elem_ >= b.elem_; } +}; + +template +struct is_sparse : false_type {}; +template +struct is_sparse : is_sparse {}; +template +struct is_sparse> : true_type {}; +template +static constexpr auto is_sparse_v = is_sparse::value; + +// Overload sizeof_bits for sparse_elem. +// Much like subbyte element types, this is the effective number of bits in a sparse_elem +// rather than actual physical bits that may be used in storing one. Also like subbyte element +// types, modified iterators are required to properly index and access sparse_elems. +// +// Defining sizeof_bits like this makes reasonable expressions like N * sizeof_bits_v meaningful +// even when E is subbyte or sparse. However, this also means that sparse_elem can rather easily be +// confused with subbyte elements and special care should be taken with each. +template +struct sizeof_bits> { + // Simple implementation that conforms to sizeof_bits + //static constexpr auto value = sizeof_bits::value / S; + //static_assert(value != 0, "sizeof_bits=0 detected. Sparsity is larger than width."); + //static_assert((sizeof_bits::value % S) == 0, "Width needs to be a multiple of sparsity.") + + // Interesting experiment that allows any sparsity level to be used by potentially presenting + // an integral_ratio rather than size_t. This is valid in most integer expressions as well. 
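+  // Worked examples (a sketch of the definition below):
+  //   sizeof_bits<sparse_elem<2, half_t >>  ->  ratio(Int<16>{}, Int<2>{})  ==  8 bits
+  //   sizeof_bits<sparse_elem<4, uint8_t>>  ->  ratio(Int< 8>{}, Int<4>{})  ==  2 bits
+  //   sizeof_bits<sparse_elem<8, uint4_t>>  ->  ratio(Int< 4>{}, Int<8>{})  ==  1/2, i.e. an integral_ratio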
+ static constexpr auto value = cute::ratio(cute::Int>{}, cute::Int{}); +}; + +// +// sparse_ptr +// + +template +struct is_sparse_ptr : false_type {}; +template +struct is_sparse_ptr> : is_sparse_ptr {}; + +template +struct sparse_ptr : iter_adaptor> +{ + using reference = typename iterator_traits::reference; + using element_type = typename iterator_traits::element_type; + using value_type = typename iterator_traits::value_type; + + // Sanity, for now + static_assert(is_sparse::value, "Enforce sparse value-type"); + static_assert(Sparsity == iter_value_t::sparsity, "Enforce sparsity S"); + static_assert(not is_sparse_ptr::value, "Enforce sparse singleton"); + + template + CUTE_HOST_DEVICE constexpr + sparse_ptr operator+(Index const& i) const { + // Only allow offset by multiples of the sparsity factor, + // else the misalignments become a bug. E.g. (sparse_ptr<8,I>{} + 7) + 7 + // Motivation for subsparse_iterator or generalization of subbyte_iterator? + assert(i % Sparsity == 0); + return {this->get() + i / Sparsity}; + } + + template + CUTE_HOST_DEVICE constexpr + reference operator[](Index const& i) const { + // Allow offset by any value and dereference. + // Not implemented in terms of sparse_ptr::op+() + return *(this->get() + i / Sparsity); + } +}; + +template +struct is_sparse_ptr> : true_type {}; + +template +CUTE_HOST_DEVICE constexpr +auto +make_sparse_ptr(Iter const& iter) { + if constexpr (Sparsity == 1) { + return iter; + } else { + return sparse_ptr{iter}; + } + CUTE_GCC_UNREACHABLE; +} + +template +CUTE_HOST_DEVICE constexpr +auto +recast_ptr(sparse_ptr const& ptr) { + static_assert(not is_sparse::value); + return recast_ptr(ptr.get()); +} + +// +// Display utilities +// + +template +CUTE_HOST_DEVICE void print(sparse_ptr ptr) +{ + printf("sparse<%d>_", S); print(ptr.get()); +} + +#if !defined(__CUDACC_RTC__) +template +CUTE_HOST std::ostream& operator<<(std::ostream& os, sparse_ptr ptr) +{ + return os << "sparse<" << S << ">_" << ptr.get(); +} +#endif + +} // end namespace cute diff --git a/include/cute/pointer_swizzle.hpp b/include/cute/pointer_swizzle.hpp index a83b485c8e..720b9b1246 100644 --- a/include/cute/pointer_swizzle.hpp +++ b/include/cute/pointer_swizzle.hpp @@ -30,13 +30,11 @@ **************************************************************************************************/ #pragma once -#include - -#include // iterator_traits -#include - -#include -#include +#include // CUTE_HOST_DEVICE +#include // cute::iter_adaptor +#include // cute::Swizzle, cute::get_swizzle primary template +#include // cute::iterator_traits +#include // cute::subbyte_iterator /* This implements a swizzle pointer of the form * InvolutionFn o PtrAdd @@ -107,16 +105,14 @@ struct swizzle_ptr : iter_adaptor> } }; -template // Default No-Swizzle -struct get_swizzle { using type = Swizzle<0,4,3>; }; +// +// Helper Function +// template // Found the SwizzleFn struct get_swizzle> { using type = SwizzleFn; }; template // Recurse into anything with a ::iterator struct get_swizzle> : get_swizzle {}; -template -using get_swizzle_t = typename get_swizzle::type; - template CUTE_HOST_DEVICE constexpr swizzle_ptr diff --git a/include/cute/stride.hpp b/include/cute/stride.hpp index 09a02a00e7..f2d31f4e34 100644 --- a/include/cute/stride.hpp +++ b/include/cute/stride.hpp @@ -30,10 +30,16 @@ **************************************************************************************************/ #pragma once -#include -#include -#include -#include +#include // CUTE_HOST_DEVICE +#include // 
cute::__CUTE_REQUIRES +#include // cute::is_tuple +#include // cute::is_integral +#include // cute::seq +#include // cute::divmod +#include // cute::basis_get +#include // cute::identity +#include // cute::fold +#include // cute::is_congruent namespace cute { @@ -433,7 +439,7 @@ compact_order(Shape const& shape, Order const& order) auto flat_order = flatten_to_tuple(order); // Find the largest static element of order auto max_order = cute::fold(flat_order, Int<0>{}, [](auto v, auto order) { - if constexpr (is_constant::value) { + if constexpr (is_constant::value) { return order; } else { return v; @@ -474,4 +480,119 @@ compact_order(Shape const& shape, GenRowMajor const& major) return compact_major(shape); } +// +// Coordinate iterator +// + +namespace detail { + +template +CUTE_HOST_DEVICE constexpr +void +increment(Coord& coord, Shape const& shape, Order const& order) +{ + ++basis_get(get<0>(order), coord); + cute::for_each(make_range<1, tuple_size::value>{}, [&](auto i){ + if (basis_get(get(order), coord) == basis_get(get(order), shape)) { + basis_get(get(order), coord) = 0; + ++basis_get(get(order), coord); + } + }); +} + +/** Increment a (dynamic) coord colexicographically within a shape + * @pre is_congruent::value + * \code + * auto shape = make_shape(1,2,make_shape(2,3),3); + * auto coord = repeat_like(shape, 0); + * + * for (int i = 0; i < size(shape); ++i) { + * std::cout << i << ": " << coord << std::endl; + * increment(coord, shape); + * } + * \endcode + */ +template +CUTE_HOST_DEVICE constexpr +void +increment(Coord& coord, Shape const& shape) +{ + increment(coord, shape, flatten_to_tuple(make_basis_like(shape))); +} + +} // end namespace detail + +struct ForwardCoordIteratorSentinel +{}; + +// A forward iterator for a starting coordinate in a shape's domain, and a shape. +// The starting coordinate may be zero but need not necessarily be. 
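+//
+// A usage sketch (colexicographic iteration over a rank-2 shape):
+//   auto shape = make_shape(2,3);
+//   auto it    = make_coord_iterator(shape);   // starts at (0,0)
+//   for (int i = 0; i < size(shape); ++i, ++it) {
+//     print(*it);                              // (0,0) (1,0) (0,1) (1,1) (0,2) (1,2)
+//   }
+//   // here it == ForwardCoordIteratorSentinel{}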
+template +struct ForwardCoordIterator +{ + static_assert(is_congruent::value); + + CUTE_HOST_DEVICE constexpr + Coord const& operator*() const { return coord; } + CUTE_HOST_DEVICE constexpr + ForwardCoordIterator& operator++() { detail::increment(coord, shape, Order{}); return *this; } + // Sentinel for the end of the implied range + CUTE_HOST_DEVICE constexpr + bool operator==(ForwardCoordIteratorSentinel const&) const { return basis_get(back(Order{}), coord) == basis_get(back(Order{}), shape); } + CUTE_HOST_DEVICE constexpr + bool operator!=(ForwardCoordIteratorSentinel const&) const { return basis_get(back(Order{}), coord) != basis_get(back(Order{}), shape); } + // NOTE: These are expensive, avoid use + CUTE_HOST_DEVICE constexpr + bool operator==(ForwardCoordIterator const& other) const { return coord == other.coord; } + CUTE_HOST_DEVICE constexpr + bool operator!=(ForwardCoordIterator const& other) const { return coord != other.coord; } + + Coord coord; + Shape const& shape; +}; + +// A forward iterator for a coordinate that starts from a provided coordinate and increments in a prescribed order +template +CUTE_HOST_DEVICE constexpr +auto +make_coord_iterator(Coord const& coord, Shape const& shape) +{ + static_assert(is_congruent::value); + static_assert(is_congruent::value); + static_assert(is_congruent::value); + auto flat_order = flatten_to_tuple(Order{}); + auto inv_order = transform(make_seq{}, [&](auto i){ return find(flat_order, i); }); + auto basis_order = transform_leaf(inv_order, [&](auto i) { return get(flatten_to_tuple(make_basis_like(shape))); }); + return ForwardCoordIterator{coord,shape}; +} + +// A forward iterator for a coordinate that starts from a provided coordinate and increments colex +template +CUTE_HOST_DEVICE constexpr +auto +make_coord_iterator(Coord const& coord, Shape const& shape) +{ + static_assert(is_congruent::value); + auto basis_order = flatten_to_tuple(make_basis_like(shape)); + return ForwardCoordIterator{coord,shape}; +} + +// A forward iterator for a coordinate that starts from zero and increments in a prescribed order +template +CUTE_HOST_DEVICE constexpr +auto +make_coord_iterator(Shape const& shape) +{ + return make_coord_iterator(repeat_like(shape, int(0)), shape); +} + +// A forward iterator for a coordinate that starts from zero and increments colex +template +CUTE_HOST_DEVICE constexpr +auto +make_coord_iterator(Shape const& shape) +{ + return make_coord_iterator(repeat_like(shape, int(0)), shape); +} + } // end namespace cute diff --git a/include/cute/swizzle.hpp b/include/cute/swizzle.hpp index 9ceb0d32b0..52abf856dd 100644 --- a/include/cute/swizzle.hpp +++ b/include/cute/swizzle.hpp @@ -30,13 +30,11 @@ **************************************************************************************************/ #pragma once -#include - -#include -#include -#include -#include -#include +#include // CUTE_HOST_DEVICE +#include // cute::is_tuple +#include // cute::constant +#include // cute::max, cute::min +#include // cute::transform_apply namespace cute { @@ -488,4 +486,13 @@ CUTE_HOST std::ostream& operator<<(std::ostream& os, MixedBits const& m) } #endif // !defined(__CUDACC_RTC__) +// +// Helper Function +// +template // Default No-Swizzle +struct get_swizzle { using type = Swizzle<0,4,3>; }; + +template +using get_swizzle_t = typename get_swizzle::type; + } // end namespace cute diff --git a/include/cute/swizzle_layout.hpp b/include/cute/swizzle_layout.hpp index 82e51c79c6..1324360eba 100644 --- a/include/cute/swizzle_layout.hpp +++ 
b/include/cute/swizzle_layout.hpp @@ -30,13 +30,10 @@ **************************************************************************************************/ #pragma once -#include - -#include -#include - -#include -#include // get_swizzle +#include // CUTE_HOST_DEVICE +#include // cute::Layout +#include // cute::ComposedLayout +#include // cute::Swizzle, cute::get_swizzle primary template /* Specialized functionality for a ComposedLayout of the form * InvolutionFn o Offset o LayoutB @@ -57,6 +54,9 @@ namespace cute { +// +// Helper Function +// template struct get_swizzle,Offset,LayoutB>> { using type = Swizzle; }; @@ -193,7 +193,7 @@ make_swizzle_strides(true_type, // 0 Z DC // 1 -Z DC - return cute::make_tuple(conditional_return((offset & (Y << Int{})) == Int<0>{}, Z << Int{}, -(Z << Int{}))...); + return cute::make_tuple(conditional_return((offset & (Y << Int{})) == Int<0>{}, Z * Int<(1 << I)>{}, -Z * Int<(1 << I)>{})...); } template @@ -214,7 +214,7 @@ make_swizzle_strides(false_type, // 0 Y+Z Y-Z // 1 DC DC - return cute::make_tuple(conditional_return((offset & (Z << Int{})) == Int<0>{}, (Y+Z) << Int{}, (Y-Z) << Int{})...); + return cute::make_tuple(conditional_return((offset & (Z << Int{})) == Int<0>{}, (Y+Z) * Int<(1 << I)>{}, (Y-Z) * Int<(1 << I)>{})...); } } // end namespace detail @@ -240,16 +240,6 @@ slice_and_offset(Coord const& coord, ComposedLayout,Offset,Layout // The portion of the layout that is not yet consumed auto sliced_layout = slice(coord, layout.layout_b()); - // If the sliced_layout hits two bits that are swizzled together, then don't attempt to decay - - // Compose with the layout to get the swizzle projection, P o L [The Z and Y contributing portions of L] - // (this also tests that shape/stride of layout compose with swizzle) - auto sliced_layout_only_zy = composition(swizzle_only_zy, sliced_layout); - // Transform the end coordinate to get the active bits of the swizzle, (P o L)(c*) - auto swizzle_active_bits = sliced_layout_only_zy(size(sliced_layout_only_zy)-Int<1>{}); - // Determine if any active bits collide under the swizzle - auto hit_ZandY = !(swizzle_active_bits & ~layout.layout_a()(swizzle_active_bits)); - // The portion of the layout that we are consuming now auto diced_layout = dice(coord, layout.layout_b()); auto diced_coord = dice(coord, coord); @@ -269,8 +259,16 @@ slice_and_offset(Coord const& coord, ComposedLayout,Offset,Layout // If Layout's codomain hits on Y XOR Z, then it's dynamic-normal // If Layout's codomain hits on neither Y NOR Z, then it's static-normal - // Test the sliced layout for hit_X & hit_Y for potential decay - if constexpr (is_constant::value) + // If the sliced_layout hits two bits that are swizzled together, then don't attempt to decay + + // Compose with the layout to get the swizzle projection, P o L [The Z and Y contributing portions of L] + // (this also tests that shape/stride of layout compose with swizzle) + auto sliced_layout_only_zy = composition(swizzle_only_zy, sliced_layout); + // Transform the end coordinate to get the active bits of the swizzle, (P o L)(c*) + [[maybe_unused]] auto swizzle_active_bits = sliced_layout_only_zy(size(sliced_layout_only_zy)-Int<1>{}); + + // Determine if any active bits collide under the swizzle for potential decay + if constexpr (is_constant<0, decltype(not (swizzle_active_bits & ~swizzle(swizzle_active_bits)))>::value) { // Hits on Y AND Z, so it's not reducible return cute::make_tuple(composition(swizzle, offset_only_zy, sliced_layout), offset_anti_zy); } else @@ -459,7 +457,7 @@ 
CUTE_HOST_DEVICE constexpr auto max_alignment(Swizzle const&) { - return Int{}; + return Int<1 << M>{}; } template diff --git a/include/cute/tensor.hpp b/include/cute/tensor.hpp index a45cbd0132..3f3335b63d 100644 --- a/include/cute/tensor.hpp +++ b/include/cute/tensor.hpp @@ -37,7 +37,10 @@ // #include +#include #include +#include + // // Tensor Algorithms // diff --git a/include/cute/tensor_impl.hpp b/include/cute/tensor_impl.hpp index da0e245636..61eefc5060 100644 --- a/include/cute/tensor_impl.hpp +++ b/include/cute/tensor_impl.hpp @@ -41,18 +41,16 @@ #pragma once -#include - -#include -#include -#include - -#include -#include -#include - -#include -#include +#include // CUTE_HOST_DEVICE +#include // cute::Shape +#include // cute::is_composed_layout +#include // cute::recast_ptr +#include // cute::iterator_traits +#include // cute::array_aligned +#include // cute::array_subbyte +#include // cute::tuple +#include // cute::is_integral +#include // __CUTE_REQUIRES namespace cute { @@ -69,7 +67,7 @@ namespace cute // iterator begin(); // }; -template +template struct ArrayEngine { using Storage = typename conditional<(sizeof_bits::value % 8 == 0), @@ -85,6 +83,24 @@ struct ArrayEngine CUTE_HOST_DEVICE constexpr auto begin() { return storage_.begin(); } }; +// Specialization for sparse_elem tensor allocation/iteration +template +struct ArrayEngine, N> +{ + static_assert(N % S == 0, "Expected a multiple of the sparsity."); + using value_type = sparse_elem; + using Storage = typename conditional<(sizeof_bits::value % 8 == 0), + array_aligned, + array_subbyte>::type; + using iterator = sparse_ptr*>; + using reference = typename iterator_traits::reference; + using element_type = typename iterator_traits::element_type; + Storage storage_; + + CUTE_HOST_DEVICE constexpr auto begin() const { return recast_ptr(storage_.begin()); } + CUTE_HOST_DEVICE constexpr auto begin() { return recast_ptr(storage_.begin()); } +}; + template struct ViewEngine { @@ -622,6 +638,30 @@ filter_zeros(Tensor&& tensor) { return make_tensor(tensor.data(), filter_zeros(tensor.layout())); } +template +CUTE_HOST_DEVICE constexpr +auto +filter_zeros(Tensor const& tensor, Profile const& profile) +{ + return make_tensor(tensor.data(), filter_zeros(tensor.layout(), profile)); +} + +template +CUTE_HOST_DEVICE constexpr +auto +filter_zeros(Tensor& tensor, Profile const& profile) +{ + return make_tensor(tensor.data(), filter_zeros(tensor.layout(), profile)); +} + +template +CUTE_HOST_DEVICE constexpr +auto +filter_zeros(Tensor&& tensor, Profile const& profile) +{ + return make_tensor(tensor.data(), filter_zeros(tensor.layout(), profile)); +} + // Remove all of the 0-strides and 1-sizes template CUTE_HOST_DEVICE constexpr @@ -755,10 +795,10 @@ auto max_common_vector(Tensor const& a, Tensor const& b) { - using SrcType = typename Tensor::value_type; - using DstType = typename Tensor::value_type; - using SrcRef = typename Tensor::reference; - using DstRef = typename Tensor::reference; + using SrcType = typename SrcEngine::value_type; + using SrcRef = typename SrcEngine::reference; + using DstType = typename DstEngine::value_type; + using DstRef = typename DstEngine::reference; // Determine if vectorization candidates at all if constexpr (// Should be the same value_types, else the copy is also performing a cast @@ -795,10 +835,10 @@ auto max_common_layout(Tensor const& a, Tensor const& b) { - using SrcType = typename Tensor::value_type; - using DstType = typename Tensor::value_type; - using SrcRef = typename Tensor::reference; - using 
DstRef = typename Tensor::reference; + using SrcType = typename SrcEngine::value_type; + using SrcRef = typename SrcEngine::reference; + using DstType = typename DstEngine::value_type; + using DstRef = typename DstEngine::reference; // Determine if vectorization candidates at all if constexpr (// Should be the same value_types, else the copy is also performing a cast diff --git a/include/cute/tensor_predicate.hpp b/include/cute/tensor_predicate.hpp index 6814647071..9c8a2ba614 100644 --- a/include/cute/tensor_predicate.hpp +++ b/include/cute/tensor_predicate.hpp @@ -30,9 +30,8 @@ **************************************************************************************************/ #pragma once -#include - -#include +#include // CUTE_HOST_DEVICE +#include // cute::true_type namespace cute { diff --git a/include/cute/tensor_zip.hpp b/include/cute/tensor_zip.hpp new file mode 100644 index 0000000000..6d70ffc847 --- /dev/null +++ b/include/cute/tensor_zip.hpp @@ -0,0 +1,243 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +#pragma once + +#include // CUTE_HOST_DEVICE +#include // cute::Tensor +#include // cute::tuple + +namespace cute +{ + +// A tuple of Iterators that can be offset asymmetrically +// Note that this only accepts op+(tuple) and op[tuple] +// where each iterator will be offset by its respective index only. +// READ-ONLY for now until cute::tuple can be constructed with references. +template +struct ZipIterator +{ + using value_type = cute::tuple...>; + using element_type = cute::tuple...>; + // NOTE: cute::tuple does not support constructions with references at the moment. + // Consider fixes and/or an implementation of std::forward_as_tuple. 
+ // For now, use a cute::tuple of value_types instead, which makes this Iterator READ-ONLY. + //using reference = cute::tuple...>; + using reference = value_type; + + ZipIterator() = delete; + + CUTE_HOST_DEVICE constexpr + ZipIterator(Iters... iters) + : iters_(iters...) + {} + + CUTE_HOST_DEVICE constexpr + ZipIterator(cute::tuple const& iters) + : iters_(iters) + {} + + CUTE_HOST_DEVICE constexpr + reference operator*() const { + return cute::apply(iters_, [](auto&&... args) { return reference(*args...); }); + } + + template + CUTE_HOST_DEVICE constexpr + ZipIterator operator+(cute::tuple const& idxs) const { + static_assert(sizeof...(Index) == sizeof...(Iters), "Expect same number of offsets as iterators."); + return cute::transform(iters_, idxs, [](auto&& iter, auto&& idx) { return iter + idx; }); + } + + template + CUTE_HOST_DEVICE constexpr + reference operator[](cute::tuple const& idxs) const { + return *(*this + idxs); + } + + cute::tuple iters_; +}; + +//------------------------------------------------------------------------------ +// type traits + +template +struct is_rmem> : conjunction...> {}; +template +struct is_smem> : conjunction...> {}; +template +struct is_gmem> : conjunction...> {}; +// A tuple of Layouts that operates on each Layout symmetrically +// The Layouts need to have compatible shapes and ranks. +// The ZipLayout presents the intersection of the domain of its component Layouts. +// E.g. all Layouts accept 1D coords and ZipLayout does as well. +// The ZipLayout returns the union of the codomain of its component Layouts. +// E.g. all Layouts return an integer so ZipLayout returns a tuple of integers. +template +struct ZipLayout +{ + static constexpr int rank = (int(0) | ... | Layouts::rank); + + static_assert((is_layout::value && ...), "All template parameters must be layouts"); + static_assert(((Layouts::rank == rank) && ...), "All layouts must have the same rank"); + + CUTE_HOST_DEVICE constexpr + ZipLayout(Layouts const&... layouts) + : layouts_(layouts...) + {} + + CUTE_HOST_DEVICE constexpr + ZipLayout(cute::tuple const& layouts) + : layouts_(layouts) + {} + + template + CUTE_HOST_DEVICE constexpr + auto + operator()(Coord const& coord) const { + if constexpr (has_underscore::value) { + return ZipLayout(cute::transform(layouts_, [&] (auto layout) { return layout(coord); })); + } else { + return cute::transform(layouts_, [&] (auto layout) { return layout(coord); }); + } + + CUTE_GCC_UNREACHABLE; + } + + // op() convenience function for multi-dimensional coordinates + template + CUTE_HOST_DEVICE constexpr + decltype(auto) + operator()(Coord0 const& c0, Coord1 const& c1, Coords const&... cs) const { + return operator()(make_coord(c0,c1,cs...)); + } + + cute::tuple layouts_; +}; + +template +struct is_layout> : true_type {}; + +// +// make_zip_tensor and unzip_tensor +// + +template +CUTE_HOST_DEVICE constexpr +auto +make_zip_tensor(Tensor const&... 
tensors) +{ + return make_tensor(ZipIterator(tensors.data()...), + ZipLayout(tensors.layout()...)); +} + +template +CUTE_HOST_DEVICE constexpr +auto +unzip_tensor(Tensor const& tensor) +{ + return cute::transform(tensor.data().iters_, tensor.layout().layouts_, + [](auto iter, auto layout) { return make_tensor(iter, layout); }); +} + +// +// Utilities +// + +template +CUTE_HOST_DEVICE constexpr +auto +rank(ZipLayout const& layouts) +{ + return rank(get<0>(layouts.layouts_)); +} + +template +CUTE_HOST_DEVICE constexpr +auto +size(ZipLayout const& layouts) +{ + return size(get<0>(layouts.layouts_)); +} + +// +// Manipulation +// + +// Extend each component layout to rank-N by appending Layout @a x. +template +CUTE_HOST_DEVICE constexpr +auto +append(ZipLayout const& layouts, + Layout const& x = {}) +{ + return ZipLayout(cute::transform(layouts.layouts_, [&](auto t){ return append(t, x); })); +} + +// Extend each component layout to rank-N by prepending Layout @a x. +template +CUTE_HOST_DEVICE constexpr +auto +prepend(ZipLayout const& layouts, + Layout const& x = {}) +{ + return ZipLayout(cute::transform(layouts.layouts_, [&](auto t){ return prepend(t, x); })); +} + +template +CUTE_HOST_DEVICE constexpr +auto +logical_divide(ZipLayout const& layouts, + Tiler const& tiler) +{ + return ZipLayout(cute::transform(layouts.layouts_, [&](auto t){ return logical_divide(t, tiler); })); +} + +template +CUTE_HOST_DEVICE constexpr +auto +zipped_divide(ZipLayout const& layouts, + Tiler const& tiler) +{ + return ZipLayout(cute::transform(layouts.layouts_, [&](auto t){ return zipped_divide(t, tiler); })); +} + +// Return by calling slice_and_offset and all component layouts. +template +CUTE_HOST_DEVICE constexpr +auto +slice_and_offset(Coord const& c, ZipLayout const& layouts) +{ + auto result = cute::zip(cute::transform(layouts.layouts_, [&c](auto const& layout) { return slice_and_offset(c, layout); })); + return cute::make_tuple(ZipLayout(get<0>(result)), get<1>(result)); +} + +} // end namespace cute diff --git a/include/cute/underscore.hpp b/include/cute/underscore.hpp index 212f42d7fa..e9d80fe5b5 100644 --- a/include/cute/underscore.hpp +++ b/include/cute/underscore.hpp @@ -30,12 +30,9 @@ **************************************************************************************************/ #pragma once -#include - -#include -#include -#include -#include +#include // CUTE_INLINE_CONSTANT, CUTE_HOST_DEVICE +#include // cute::is_tuple +#include // cute::false_type, cute::true_type namespace cute { diff --git a/include/cute/util/print.hpp b/include/cute/util/print.hpp index 6463e8684f..6bfe6c0a1e 100644 --- a/include/cute/util/print.hpp +++ b/include/cute/util/print.hpp @@ -30,9 +30,8 @@ **************************************************************************************************/ #pragma once -#include - -#include +#include // CUTE_HOST_DEVICE +#include // cute::is_valid // // CUDA compatible print and printf @@ -156,50 +155,45 @@ print(char const* format) { // pretty printing // -template -CUTE_HOST_DEVICE void -pretty_print(T const& v) { - printf(" "); print(v); -} - CUTE_HOST_DEVICE void -pretty_print(bool const& v) { +pretty_print(bool v) { printf("%*d", 3, int(v)); } CUTE_HOST_DEVICE void -pretty_print(int32_t const& v) { +pretty_print(int32_t v) { printf("%*d", 5, v); } CUTE_HOST_DEVICE void -pretty_print(uint32_t const& v) { +pretty_print(uint32_t v) { printf("%*d", 5, v); } CUTE_HOST_DEVICE void -pretty_print(int64_t const& v) { +pretty_print(int64_t v) { printf("%*lld", 5, static_cast(v)); } 
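+// For example, the field widths used by these overloads (and by the overloads in
+// numeric_types.hpp) line values up in columns:
+//   pretty_print(int32_t(42));    // "   42"        (width  5)
+//   pretty_print(half_t(1.5f));   // "    1.50"     (width  8, %.2f)
+//   pretty_print(3.14f);          // "  3.14e+00"   (width 10, %.2e)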
CUTE_HOST_DEVICE void -pretty_print(uint64_t const& v) { +pretty_print(uint64_t v) { printf("%*llu", 5, static_cast(v)); } CUTE_HOST_DEVICE void -pretty_print(half_t const& v) { - printf("%*.2f", 8, float(v)); +pretty_print(float v) { + printf("%*.2e", 10, v); } CUTE_HOST_DEVICE void -pretty_print(float const& v) { - printf("%*.2e", 10, v); +pretty_print(double v) { + printf("%*.3e", 11, v); } +template CUTE_HOST_DEVICE void -pretty_print(double const& v) { - printf("%*.3e", 11, v); +pretty_print(T t) { + printf(" "); print(t); } } // end namespace cute diff --git a/include/cute/util/type_traits.hpp b/include/cute/util/type_traits.hpp index f0eb55116d..e663b569c6 100644 --- a/include/cute/util/type_traits.hpp +++ b/include/cute/util/type_traits.hpp @@ -44,7 +44,7 @@ #include // numeric_limits #endif -#include +#include // CUTE_STL_NAMESPACE namespace cute { @@ -79,6 +79,7 @@ using CUTE_STL_NAMESPACE::is_const_v; using CUTE_STL_NAMESPACE::is_volatile; using CUTE_STL_NAMESPACE::is_volatile_v; +// Defined in cute/numeric/integral_constant.hpp // using CUTE_STL_NAMESPACE::true_type; // using CUTE_STL_NAMESPACE::false_type; @@ -278,14 +279,14 @@ struct conditional_template { // is_any_of // -/// Member `value` is true if and only if T is same as (is_same_v) at least one of the types in Us -template +// Member `value` is true if and only if T is same as (is_same_v) at least one of the types in Us +template struct is_any_of { constexpr static bool value = (... || CUTE_STL_NAMESPACE::is_same_v); }; -/// Is true if and only if T is same as (is_same_v) at least one of the types in Us -template +// Is true if and only if T is same as (is_same_v) at least one of the types in Us +template inline constexpr bool is_any_of_v = is_any_of::value; } // end namespace cute diff --git a/include/cutlass/arch/barrier.h b/include/cutlass/arch/barrier.h index cd2d7be3cb..c96897324a 100644 --- a/include/cutlass/arch/barrier.h +++ b/include/cutlass/arch/barrier.h @@ -93,12 +93,24 @@ class NamedBarrier { NamedBarrier::arrive_and_wait_internal(num_threads_, id_); } + CUTLASS_DEVICE + void arrive_and_wait_unaligned() const { + // Note: The value of id_ is already the final barrier id (set correctly in the constructor). + NamedBarrier::arrive_and_wait_internal_unaligned(num_threads_, id_); + } + CUTLASS_DEVICE void arrive() const { // Note: The value of id_ is already the final barrier id (set correctly in the constructor). NamedBarrier::arrive_internal(num_threads_, id_); } + CUTLASS_DEVICE + void arrive_unaligned() const { + // Note: The value of id_ is already the final barrier id (set correctly in the constructor). 
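+    // Note: unlike arrive(), which lowers to PTX "bar.arrive", this variant lowers to
+    // "barrier.arrive" (the form without the .aligned qualifier implied by "bar.*"),
+    // and is intended for call sites where not all threads execute the barrier uniformly.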
+ NamedBarrier::arrive_internal_unaligned(num_threads_, id_); + } + CUTLASS_DEVICE void sync() const { NamedBarrier::arrive_and_wait(); @@ -148,11 +160,23 @@ class NamedBarrier { sync_internal(num_threads, static_cast(reserved_named_barriers)); } + private: CUTLASS_DEVICE static void arrive_and_wait_internal(uint32_t num_threads, uint32_t barrier_id) { #if CUDA_BARRIER_ENABLED asm volatile("bar.sync %0, %1;" : : "r"(barrier_id), "r"(num_threads)); + cutlass::arch::synclog_emit_named_barrier_arrive_and_wait(__LINE__, num_threads, barrier_id); +#elif defined(__CUDA_ARCH__) + asm volatile ("brkpt;\n" ::); +#endif + } + + CUTLASS_DEVICE + static void arrive_and_wait_internal_unaligned(uint32_t num_threads, uint32_t barrier_id) { +#if CUDA_BARRIER_ENABLED + asm volatile("barrier.sync %0, %1;" : : "r"(barrier_id), "r"(num_threads)); + cutlass::arch::synclog_emit_named_barrier_arrive_and_wait(__LINE__, num_threads, barrier_id); #elif defined(__CUDA_ARCH__) asm volatile ("brkpt;\n" ::); #endif @@ -161,12 +185,23 @@ class NamedBarrier { CUTLASS_DEVICE static void arrive_internal(uint32_t num_threads, uint32_t barrier_id) { #if CUDA_BARRIER_ENABLED + cutlass::arch::synclog_emit_named_barrier_arrive(__LINE__, num_threads, barrier_id); asm volatile("bar.arrive %0, %1;" : : "r"(barrier_id), "r"(num_threads)); #elif defined(__CUDA_ARCH__) asm volatile ("brkpt;\n" ::); #endif } + CUTLASS_DEVICE + static void arrive_internal_unaligned(uint32_t num_threads, uint32_t barrier_id) { +#if CUDA_BARRIER_ENABLED + cutlass::arch::synclog_emit_named_barrier_arrive(__LINE__, num_threads, barrier_id); + asm volatile("barrier.arrive %0, %1;" : : "r"(barrier_id), "r"(num_threads)); +#elif defined(__CUDA_ARCH__) + asm volatile ("brkpt;\n" ::); +#endif + } + CUTLASS_DEVICE static void sync_internal(uint32_t num_threads, uint32_t barrier_id) { NamedBarrier::arrive_and_wait_internal(num_threads, barrier_id); @@ -243,6 +278,7 @@ struct ClusterBarrier { "}" : : "r"(arrive_count), "r"(smem_addr)); + cutlass::arch::synclog_emit_cluster_barrier_init(__LINE__, smem_addr, arrive_count); #elif defined(__CUDA_ARCH__) asm volatile ("brkpt;\n" ::); #endif @@ -253,6 +289,7 @@ struct ClusterBarrier { static void wait(ValueType const* smem_ptr, uint32_t phase) { #if CUDA_BARRIER_ENABLED uint32_t smem_addr = cute::cast_smem_ptr_to_uint(smem_ptr); + cutlass::arch::synclog_emit_cluster_barrier_wait(__LINE__, smem_addr, phase); // Arbitrarily large timer value after which try-wait expires and re-tries. 
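    // (0x989680 == 10,000,000 decimal; roughly 10 ms if the hint is interpreted in nanoseconds)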
uint32_t ticks = 0x989680; asm volatile( @@ -276,6 +313,7 @@ struct ClusterBarrier { static bool test_wait(ValueType const* smem_ptr, uint32_t phase, uint32_t pred) { #if CUDA_BARRIER_ENABLED uint32_t smem_addr = cute::cast_smem_ptr_to_uint(smem_ptr); + cutlass::arch::synclog_emit_cluster_barrier_test_wait(__LINE__, smem_addr, phase, pred); uint32_t waitComplete; asm volatile( @@ -300,6 +338,7 @@ struct ClusterBarrier { static bool try_wait(ValueType const* smem_ptr, uint32_t phase) { #if CUDA_BARRIER_ENABLED uint32_t smem_addr = cute::cast_smem_ptr_to_uint(smem_ptr); + cutlass::arch::synclog_emit_cluster_barrier_try_wait(__LINE__, smem_addr, phase); uint32_t waitComplete; asm volatile( @@ -334,6 +373,7 @@ struct ClusterBarrier { : "r"(smem_addr), "r"(cta_id)); } + cutlass::arch::synclog_emit_cluster_barrier_arrive_cluster(__LINE__, smem_addr, cta_id, pred); #elif defined(__CUDA_ARCH__) asm volatile ("brkpt;\n" ::); #endif @@ -350,6 +390,7 @@ struct ClusterBarrier { "}" : : "r"(smem_addr)); + cutlass::arch::synclog_emit_cluster_barrier_arrive(__LINE__, smem_addr); #elif defined(__CUDA_ARCH__) asm volatile ("brkpt;\n" ::); #endif @@ -426,6 +467,7 @@ struct ClusterTransactionBarrier : public ClusterBarrier { "}" : : "r"(transaction_bytes), "r"(smem_addr)); + cutlass::arch::synclog_emit_cluster_transaction_barrier_arrive_and_expect_tx(__LINE__, smem_addr, transaction_bytes); #elif defined(__CUDA_ARCH__) asm volatile ("brkpt;\n" ::); #endif @@ -463,6 +505,7 @@ struct ClusterTransactionBarrier : public ClusterBarrier { "}" : : "r"(transaction_bytes), "r"(smem_addr)); + cutlass::arch::synclog_emit_cluster_transaction_barrier_expect_transaction(__LINE__, smem_addr, transaction_bytes); #elif defined(__CUDA_ARCH__) asm volatile ("brkpt;\n" ::); #endif @@ -483,6 +526,7 @@ struct ClusterTransactionBarrier : public ClusterBarrier { "}" : : "r"(transaction_bytes), "r"(smem_addr), "r"(pred)); + cutlass::arch::synclog_emit_cluster_transaction_barrier_complete_transaction(__LINE__, smem_addr, dst_cta_id, transaction_bytes, pred); #elif defined(__CUDA_ARCH__) asm volatile ("brkpt;\n" ::); #endif @@ -536,6 +580,7 @@ struct ClusterTransactionBarrier : public ClusterBarrier { CUTLASS_DEVICE void fence_barrier_init() { #if CUDA_BARRIER_ENABLED + cutlass::arch::synclog_emit_fence_barrier_init(__LINE__); asm volatile( "{\n\t" "fence.mbarrier_init.release.cluster; \n" @@ -550,6 +595,7 @@ void fence_barrier_init() { CUTLASS_DEVICE void fence_view_async_shared() { #if CUDA_BARRIER_ENABLED + cutlass::arch::synclog_emit_fence_view_async_shared(__LINE__); asm volatile ( "{\n\t" "fence.proxy.async.shared::cta; \n" @@ -571,6 +617,7 @@ void cpasync_barrier_arrive(uint64_t const* smem_ptr) { "}" : : "r"(smem_addr)); + cutlass::arch::synclog_emit_cpasync_barrier_arrive(__LINE__, smem_addr); #elif defined(__CUDA_ARCH__) asm volatile ("brkpt;\n" ::); #endif diff --git a/include/cutlass/arch/config.h b/include/cutlass/arch/config.h new file mode 100644 index 0000000000..b0f750063c --- /dev/null +++ b/include/cutlass/arch/config.h @@ -0,0 +1,81 @@ +/*************************************************************************************************** + * Copyright (c) 2024 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. 
Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! \file + \brief Definitions for architecture macros +*/ + +#pragma once + +///////////////////////////////////////////////////////////////////////////////////////////////// + +// SM90 +#if (__CUDACC_VER_MAJOR__ > 12 || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ >= 0)) + #define CUTLASS_ARCH_MMA_SM90_SUPPORTED 1 + #if (!defined(CUTLASS_ARCH_MMA_SM90_ENABLED) && defined(__CUDA_ARCH__) && __CUDA_ARCH__ == 900) + #define CUTLASS_ARCH_MMA_SM90_ENABLED 1 + + #if (!defined(CUTLASS_ARCH_MMA_SM90A_ENABLED) && defined(__CUDA_ARCH_FEAT_SM90_ALL)) + #define CUTLASS_ARCH_MMA_SM90A_ENABLED 1 + #endif + #endif +#endif + +#if (__CUDACC_VER_MAJOR__ >= 12 && __CUDACC_VER_MINOR__ >= 2) + #define CUTLASS_ARCH_MMA_SPARSE_SM90_SUPPORTED +#endif + +///////////////////////////////////////////////////////////////////////////////////////////////// + +// SM90 Modifiable +#if (__CUDACC_VER_MAJOR__ > 12 || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ >= 3)) + #define CUTLASS_ARCH_MMA_MODIFIABLE_TMA_SM90_SUPPORTED 1 + #if (!defined(CUTLASS_ARCH_MMA_MODIFIABLE_TMA_SM90_ENABLED) && defined(__CUDA_ARCH__) && __CUDA_ARCH__ == 900) + #define CUTLASS_ARCH_MMA_MODIFIABLE_TMA_SM90_ENABLED 1 + + #if (!defined(CUTLASS_ARCH_MMA_MODIFIABLE_TMA_SM90A_ENABLED) && defined(__CUDA_ARCH_FEAT_SM90_ALL)) + #define CUTLASS_ARCH_MMA_MODIFIABLE_TMA_SM90A_ENABLED 1 + #endif + #endif +#endif + +///////////////////////////////////////////////////////////////////////////////////////////////// + +// SM90 F64 +#if (__CUDACC_VER_MAJOR__ > 11 || (__CUDACC_VER_MAJOR__ == 11 && __CUDACC_VER_MINOR__ >= 8)) + #define CUTLASS_ARCH_MMA_SM90_F64_MMA_SUPPORTED 1 + #if (!defined(CUTLASS_ARCH_MMA_SM90_F64_MMA_ENABLED) && defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 900) + #define CUTLASS_ARCH_MMA_SM90_F64_MMA_ENABLED 1 + #endif +#endif + +///////////////////////////////////////////////////////////////////////////////////////////////// + diff --git a/include/cutlass/arch/grid_dependency_control.h b/include/cutlass/arch/grid_dependency_control.h new file mode 100644 index 0000000000..14ef197497 --- 
/dev/null +++ b/include/cutlass/arch/grid_dependency_control.h @@ -0,0 +1,84 @@ +/*************************************************************************************************** + * Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! \file + \brief Grid dependent control (GDC) helpers for programmatic dependent launches (PDL). +*/ + +#pragma once + +#include "cute/arch/cluster_sm90.hpp" +#include "cutlass/arch/barrier.h" +#include "cutlass/conv/dispatch_policy.hpp" +#include "cutlass/gemm/dispatch_policy.hpp" + +#ifndef CUTLASS_GDC_ENABLED + #if (defined(CUTLASS_ENABLE_GDC_FOR_SM90) && \ + __CUDACC_VER_MAJOR__ >= 12 && \ + defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 900 && defined(__CUDA_ARCH_FEAT_SM90_ALL)) + #define CUTLASS_GDC_ENABLED + #endif +#endif + +namespace cutlass { +namespace arch { + +// Issuing the launch_dependents instruction hints a dependent kernel to launch earlier +// launch_dependents doesn't impact the functionality but the performance: +// Launching a dependent kernel too early can compete with current kernels, +// while launching too late can lead to a long latency. +CUTLASS_DEVICE +void launch_dependent_grids() { +#if (defined(CUTLASS_GDC_ENABLED)) + asm volatile("griddepcontrol.launch_dependents;"); +#endif +} + +// Issuing the griddepcontrol.wait instruction enforces no global memory access +// prior to this istruction. This ensures the correctness of global memory access +// when launching a dependent kernel earlier. 
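+//
+// A usage sketch (illustrative only; the two kernels below are hypothetical): the current
+// kernel issues launch_dependent_grids() at a point chosen for performance, and the
+// dependent kernel issues wait_on_dependent_grids() before its first global memory access.
+//
+//   __global__ void producer_kernel(...) {
+//     // ... main work ...
+//     cutlass::arch::launch_dependent_grids();   // hint: dependent grid may begin launching
+//     // ... remaining work, e.g. epilogue stores ...
+//   }
+//
+//   __global__ void dependent_kernel(...) {
+//     cutlass::arch::wait_on_dependent_grids();  // no global memory access before this point
+//     // ... consume what producer_kernel wrote ...
+//   }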
+CUTLASS_DEVICE +void wait_on_dependent_grids() { +#if (defined(CUTLASS_GDC_ENABLED)) + asm volatile("griddepcontrol.wait;"); +#endif +} + +// Enable kernel-level query regarding whether the GDC feature is turned on +#if (defined(CUTLASS_GDC_ENABLED)) +static constexpr bool IsGdcGloballyEnabled = true; +#else +static constexpr bool IsGdcGloballyEnabled = false; +#endif + + +} // namespace arch +} // namespace cutlass diff --git a/include/cutlass/arch/memory_sm80.h b/include/cutlass/arch/memory_sm80.h index acaa819567..cb0ba4b54b 100644 --- a/include/cutlass/arch/memory_sm80.h +++ b/include/cutlass/arch/memory_sm80.h @@ -326,6 +326,8 @@ struct cp_async { "cp.async only supports CacheOperation::Global when access size is 16B."); unsigned smem_int_ptr = cutlass_get_smem_pointer(smem_ptr); + cutlass::arch::synclog_emit_cp_async(__LINE__, smem_int_ptr, global_ptr, pred_guard, SizeInBytes); + asm volatile( "{\n" " .reg .pred p;\n" @@ -364,6 +366,8 @@ struct cp_async_zfill { unsigned smem_int_ptr = cutlass_get_smem_pointer(smem_ptr); int src_in_bytes = (pred_guard ? SizeInBytes : 0); + cutlass::arch::synclog_emit_cp_async_zfill(__LINE__, smem_int_ptr, global_ptr, pred_guard, SizeInBytes); + asm volatile( #if CUTLASS_ENABLE_L2_PREFETCH "cp.async.cg.shared.global.L2::128B [%0], [%1], %2, %3;\n" ::"r"(smem_int_ptr), @@ -401,6 +405,8 @@ struct cp_async_nan<16, CacheOperation::Global> { OOB_NAN_F16x2, OOB_NAN_F16x2}; unsigned smem_int_ptr = cutlass_get_smem_pointer(smem_ptr); + cutlass::arch::synclog_emit_cp_async_nan(__LINE__, smem_int_ptr, global_ptr, pred_guard); + asm volatile( "{\n" " .reg .pred p;\n" @@ -434,6 +440,7 @@ CUTLASS_DEVICE void cp_async_fence() { #if CUDA_CP_ASYNC_ACTIVATED asm volatile("cp.async.commit_group;\n" ::); + cutlass::arch::synclog_emit_cp_async_fence(__LINE__); #endif } @@ -444,6 +451,7 @@ template CUTLASS_DEVICE void cp_async_wait() { #if CUDA_CP_ASYNC_ACTIVATED asm volatile("cp.async.wait_group %0;\n" ::"n"(N)); + cutlass::arch::synclog_emit_cp_async_wait(__LINE__, N); #endif } @@ -452,6 +460,7 @@ template <> CUTLASS_DEVICE void cp_async_wait<0>() { #if CUDA_CP_ASYNC_ACTIVATED asm volatile("cp.async.wait_all;\n" ::); + cutlass::arch::synclog_emit_cp_async_wait_all(__LINE__); #endif } diff --git a/include/cutlass/arch/mma_sm90.h b/include/cutlass/arch/mma_sm90.h index d2b167a7ce..1183ee5e05 100644 --- a/include/cutlass/arch/mma_sm90.h +++ b/include/cutlass/arch/mma_sm90.h @@ -43,30 +43,7 @@ #include "mma.h" #include "cutlass/layout/matrix.h" #include "cutlass/numeric_types.h" - -//////////////////////////////////////////////////////////////////////////////// - -#if ((__CUDACC_VER_MAJOR__ > 11) || (__CUDACC_VER_MAJOR__ == 11 && __CUDACC_VER_MINOR__ >= 8)) - #define CUTLASS_ARCH_MMA_SM90_F64_MMA_SUPPORTED - #if (!defined(CUTLASS_ARCH_MMA_SM90_F64_MMA_ENABLED)) - #if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) - #define CUTLASS_ARCH_MMA_SM90_F64_MMA_ENABLED - #endif - #endif -#endif - -#if (__CUDACC_VER_MAJOR__ >= 12) - #define CUTLASS_ARCH_MMA_SM90_SUPPORTED - #if (!defined(CUTLASS_ARCH_MMA_SM90_ENABLED)) - #if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) - #define CUTLASS_ARCH_MMA_SM90_ENABLED - #endif - #endif -#endif - -#if ((__CUDACC_VER_MAJOR__ > 12) || ((__CUDACC_VER_MAJOR__ == 12) && (__CUDACC_VER_MINOR__ >= 3))) - #define CUTLASS_ARCH_MMA_MODIFIABLE_TMA_SM90_SUPPORTED -#endif +#include "cutlass/arch/config.h" //////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/arch/reg_reconfig.h 
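+
+    Logging is compiled in only when the CUTLASS_ENABLE_SYNCLOG macro is defined. Each
+    record stores an event header, the source line, the %globaltimer value, and the
+    emitting thread and block indices (see synclog_emit_prefix below).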
b/include/cutlass/arch/reg_reconfig.h index c1ffbeeb57..d2b434453e 100644 --- a/include/cutlass/arch/reg_reconfig.h +++ b/include/cutlass/arch/reg_reconfig.h @@ -37,9 +37,11 @@ #include "cutlass/cutlass.h" -#if (defined(__CUDA_ARCH__) &&\ - (__CUDA_ARCH__ >= 900) && (__CUDACC_VER_MAJOR__ >= 12) && defined(__CUDA_ARCH_FEAT_SM90_ALL)) +#ifndef CUDA_CTA_RECONFIG_ACTIVATED + #if (__CUDACC_VER_MAJOR__ >= 12 && \ + defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 900 && defined(__CUDA_ARCH_FEAT_SM90_ALL)) #define CUDA_CTA_RECONFIG_ACTIVATED 1 + #endif #endif namespace cutlass { diff --git a/include/cutlass/arch/synclog.hpp b/include/cutlass/arch/synclog.hpp new file mode 100644 index 0000000000..ea683859a3 --- /dev/null +++ b/include/cutlass/arch/synclog.hpp @@ -0,0 +1,1324 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Synchronization event logging for race condition debugging. 
+*/ + +#pragma once + +#include "cutlass/detail/helper_macros.hpp" + +#if defined(__CUDACC_RTC__) +#include +#else +#include +#endif + +#if !defined(__CUDACC_RTC__) +#include +#include +#endif + +namespace cutlass { +namespace arch { + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTLASS_ENABLE_SYNCLOG) + +constexpr uint32_t synclog_cap = 1 << 26; + +inline std::mutex synclog_mutex; +inline std::vector synclog_buf_list; +#if defined(__NVCC__) || (defined(__clang__) && defined(__CUDA__)) +inline __device__ uint32_t* synclog_buf; +#endif + +CUTLASS_DEVICE +uint32_t* synclog_alloc(uint32_t n) { + #if defined(__NVCC__) || (defined(__clang__) && defined(__CUDA__)) + uint32_t* buf = synclog_buf; + if (buf == nullptr) return nullptr; + uint32_t last = atomicAdd(&buf[0], n); + if (last + n < synclog_cap) return buf + last + 1; + if (last >= synclog_cap) atomicAdd(&buf[0], -n); + #endif + return nullptr; +} + +CUTLASS_DEVICE +void synclog_emit_prefix(uint32_t* to, uint32_t header, uint32_t line) { + #if defined(__NVCC__) || (defined(__clang__) && defined(__CUDA__)) + uint64_t time64; + asm volatile ( + "mov.u64 %0, %%globaltimer;\n" + : "=l"(time64) : + ); + to[0] = header; + to[1] = line; + to[2] = time64; + to[3] = time64 >> 32; + to[4] = threadIdx.x; + to[5] = threadIdx.y; + to[6] = threadIdx.z; + to[7] = blockIdx.x; + to[8] = blockIdx.y; + to[9] = blockIdx.z; + #endif +} + +constexpr uint32_t synclog_header_none = 0; +constexpr uint32_t synclog_length_prefix = 1 + 1 + 2 + 3 + 3; + +constexpr bool synclog_enable_syncthreads = true; +constexpr uint32_t synclog_header_syncthreads = 1; +constexpr uint32_t synclog_length_syncthreads = synclog_length_prefix + 0; + +constexpr bool synclog_enable_syncwarp = true; +constexpr uint32_t synclog_header_syncwarp = 2; +constexpr uint32_t synclog_length_syncwarp = synclog_length_prefix + 0; + +constexpr bool synclog_enable_named_barrier_arrive_and_wait = true; +constexpr uint32_t synclog_header_named_barrier_arrive_and_wait = 3; +constexpr uint32_t synclog_length_named_barrier_arrive_and_wait = synclog_length_prefix + 2; + +constexpr bool synclog_enable_named_barrier_arrive = true; +constexpr uint32_t synclog_header_named_barrier_arrive = 4; +constexpr uint32_t synclog_length_named_barrier_arrive = synclog_length_prefix + 2; + +constexpr bool synclog_enable_cluster_barrier_init = true; +constexpr uint32_t synclog_header_cluster_barrier_init = 5; +constexpr uint32_t synclog_length_cluster_barrier_init = synclog_length_prefix + 2; + +constexpr bool synclog_enable_cluster_barrier_wait = true; +constexpr uint32_t synclog_header_cluster_barrier_wait = 6; +constexpr uint32_t synclog_length_cluster_barrier_wait = synclog_length_prefix + 4; + +constexpr bool synclog_enable_cluster_barrier_test_wait = true; +constexpr uint32_t synclog_header_cluster_barrier_test_wait = 7; +constexpr uint32_t synclog_length_cluster_barrier_test_wait = synclog_length_prefix + 5; + +constexpr bool synclog_enable_cluster_barrier_try_wait = true; +constexpr uint32_t synclog_header_cluster_barrier_try_wait = 8; +constexpr uint32_t synclog_length_cluster_barrier_try_wait = synclog_length_prefix + 4; + +constexpr bool synclog_enable_cluster_barrier_arrive_cluster = true; +constexpr uint32_t synclog_header_cluster_barrier_arrive_cluster = 9; +constexpr uint32_t synclog_length_cluster_barrier_arrive_cluster = synclog_length_prefix + 5; + +constexpr bool synclog_enable_cluster_barrier_arrive = true; +constexpr uint32_t 
synclog_header_cluster_barrier_arrive = 10; +constexpr uint32_t synclog_length_cluster_barrier_arrive = synclog_length_prefix + 3; + +constexpr bool synclog_enable_cluster_barrier_invalidate = true; +constexpr uint32_t synclog_header_cluster_barrier_invalidate = 11; +constexpr uint32_t synclog_length_cluster_barrier_invalidate = synclog_length_prefix + 3; + +constexpr bool synclog_enable_cluster_transaction_barrier_arrive_and_expect_tx = true; +constexpr uint32_t synclog_header_cluster_transaction_barrier_arrive_and_expect_tx = 12; +constexpr uint32_t synclog_length_cluster_transaction_barrier_arrive_and_expect_tx = synclog_length_prefix + 4; + +constexpr bool synclog_enable_cluster_transaction_barrier_arrive_and_expect_tx_cluster = true; +constexpr uint32_t synclog_header_cluster_transaction_barrier_arrive_and_expect_tx_cluster = 13; +constexpr uint32_t synclog_length_cluster_transaction_barrier_arrive_and_expect_tx_cluster = synclog_length_prefix + 6; + +constexpr bool synclog_enable_cluster_transaction_barrier_expect_transaction = true; +constexpr uint32_t synclog_header_cluster_transaction_barrier_expect_transaction = 14; +constexpr uint32_t synclog_length_cluster_transaction_barrier_expect_transaction = synclog_length_prefix + 4; + +constexpr bool synclog_enable_cluster_transaction_barrier_complete_transaction = true; +constexpr uint32_t synclog_header_cluster_transaction_barrier_complete_transaction = 15; +constexpr uint32_t synclog_length_cluster_transaction_barrier_complete_transaction = synclog_length_prefix + 6; + +constexpr bool synclog_enable_fence_barrier_init = true; +constexpr uint32_t synclog_header_fence_barrier_init = 16; +constexpr uint32_t synclog_length_fence_barrier_init = synclog_length_prefix + 0; + +constexpr bool synclog_enable_fence_view_async_shared = true; +constexpr uint32_t synclog_header_fence_view_async_shared = 17; +constexpr uint32_t synclog_length_fence_view_async_shared = synclog_length_prefix + 0; + +constexpr bool synclog_enable_cp_async_wait = true; +constexpr uint32_t synclog_header_cp_async_wait = 18; +constexpr uint32_t synclog_length_cp_async_wait = synclog_length_prefix + 1; + +constexpr bool synclog_enable_cp_async_wait_all = true; +constexpr uint32_t synclog_header_cp_async_wait_all = 19; +constexpr uint32_t synclog_length_cp_async_wait_all = synclog_length_prefix + 0; + +constexpr bool synclog_enable_cp_async_fence = true; +constexpr uint32_t synclog_header_cp_async_fence = 20; +constexpr uint32_t synclog_length_cp_async_fence = synclog_length_prefix + 0; + +constexpr bool synclog_enable_cp_async_nan = true; +constexpr uint32_t synclog_header_cp_async_nan = 21; +constexpr uint32_t synclog_length_cp_async_nan = synclog_length_prefix + 4; + +constexpr bool synclog_enable_cp_async_zfill = true; +constexpr uint32_t synclog_header_cp_async_zfill = 22; +constexpr uint32_t synclog_length_cp_async_zfill = synclog_length_prefix + 5; + +constexpr bool synclog_enable_cp_async = true; +constexpr uint32_t synclog_header_cp_async = 23; +constexpr uint32_t synclog_length_cp_async = synclog_length_prefix + 5; + +constexpr bool synclog_enable_tma_load = true; +constexpr uint32_t synclog_header_tma_load = 24; +constexpr uint32_t synclog_length_tma_load = synclog_length_prefix + 4; + +constexpr bool synclog_enable_tma_store = true; +constexpr uint32_t synclog_header_tma_store = 25; +constexpr uint32_t synclog_length_tma_store = synclog_length_prefix + 3; + +constexpr bool synclog_enable_tma_store_arrive = true; +constexpr uint32_t 
synclog_header_tma_store_arrive = 26; +constexpr uint32_t synclog_length_tma_store_arrive = synclog_length_prefix + 0; + +constexpr bool synclog_enable_tma_store_wait = true; +constexpr uint32_t synclog_header_tma_store_wait = 27; +constexpr uint32_t synclog_length_tma_store_wait = synclog_length_prefix + 1; + +constexpr bool synclog_enable_warpgroup_arrive = true; +constexpr uint32_t synclog_header_warpgroup_arrive = 28; +constexpr uint32_t synclog_length_warpgroup_arrive = synclog_length_prefix + 0; + +constexpr bool synclog_enable_warpgroup_wait = true; +constexpr uint32_t synclog_header_warpgroup_wait = 29; +constexpr uint32_t synclog_length_warpgroup_wait = synclog_length_prefix + 1; + +constexpr bool synclog_enable_warpgroup_commit_batch = true; +constexpr uint32_t synclog_header_warpgroup_commit_batch = 30; +constexpr uint32_t synclog_length_warpgroup_commit_batch = synclog_length_prefix + 0; + +constexpr bool synclog_enable_wgmma_reg_smem = true; +constexpr uint32_t synclog_header_wgmma_reg_smem = 31; +constexpr uint32_t synclog_length_wgmma_reg_smem = synclog_length_prefix + 2; + +constexpr bool synclog_enable_wgmma_smem_smem = true; +constexpr uint32_t synclog_header_wgmma_smem_smem = 32; +constexpr uint32_t synclog_length_wgmma_smem_smem = synclog_length_prefix + 4; + +constexpr bool synclog_enable_cpasync_barrier_arrive = true; +constexpr uint32_t synclog_header_cpasync_barrier_arrive = 33; +constexpr uint32_t synclog_length_cpasync_barrier_arrive = synclog_length_prefix + 3; + +CUTLASS_DEVICE +bool synclog_condition_emit() { + #if defined(__NVCC__) || (defined(__clang__) && defined(__CUDA__)) + return threadIdx.x%NumThreadsPerWarp == 0 && threadIdx.y == 0 && threadIdx.z == 0 && + blockIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0; + #else + return 0; + #endif +} + +CUTLASS_DEVICE +bool synclog_condition_print() { + #if defined(__NVCC__) || (defined(__clang__) && defined(__CUDA__)) + return threadIdx.x == 0 && threadIdx.y == 0 && threadIdx.z == 0 && + blockIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0; + #else + return false; + #endif +} + +CUTLASS_DEVICE +void synclog_print_prefix(char const* header, uint32_t at) { + #if defined(__NVCC__) || (defined(__clang__) && defined(__CUDA__)) + uint32_t line = synclog_buf[at + 1]; + uint32_t timeLo = synclog_buf[at + 2]; + uint32_t timeHi = synclog_buf[at + 3]; + uint32_t threadIdxX = synclog_buf[at + 4]; + uint32_t threadIdxY = synclog_buf[at + 5]; + uint32_t threadIdxZ = synclog_buf[at + 6]; + uint32_t blockIdxX = synclog_buf[at + 7]; + uint32_t blockIdxY = synclog_buf[at + 8]; + uint32_t blockIdxZ = synclog_buf[at + 9]; + printf( + "%s line=%u time=%lu thread=%u,%u,%u block=%u,%u,%u ", + header, line, + (uint64_t)timeHi << 32 | timeLo, + threadIdxX, threadIdxY, threadIdxZ, + blockIdxX, blockIdxY, blockIdxZ + ); + #endif +} + +CUTLASS_DEVICE +uint64_t synclog_mbarrier_bits(uint32_t smem_addr) { + uint64_t bits = 0; + asm volatile ( + "mbarrier.inval.shared::cta.b64 [%1];\n" + "ld.shared::cta.b64 %0, [%1];\n" + : "=l"(bits) : "r"(smem_addr) + ); + return bits; +} + +CUTLASS_DEVICE +void synclog_print_wgmma_desc(char const* str, uint32_t lo, uint32_t hi, char const* sep) { + CUTLASS_UNUSED(hi); + uint32_t smem_int_ptr = (lo & ((1 << 14) - 1)) << 4; + printf("%s_smem_int_ptr=%u%s", str, smem_int_ptr, sep); +} + +#endif // defined(CUTLASS_ENABLE_SYNCLOG) + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline void synclog_setup() { + #if defined(CUTLASS_ENABLE_SYNCLOG) + #if 
defined(__NVCC__) || (defined(__clang__) && defined(__CUDA__)) + std::scoped_lock lock(synclog_mutex); + auto fail = [] () { + fprintf(stderr, "synclog_setup() failed\n"); + std::terminate(); + }; + int orig_device = 0; + if (cudaGetDevice(&orig_device) != cudaSuccess) { + fail(); + } + int device_count = 0; + if (cudaGetDeviceCount(&device_count) != cudaSuccess) { + fail(); + } + if (synclog_buf_list.size() == 0) { + for (int device = 0; device < device_count; device++) { + uint32_t* buf = 0; + if (cudaSetDevice(device) != cudaSuccess || + cudaMalloc(&buf, synclog_cap * sizeof(uint32_t)) != cudaSuccess) { + fail(); + } + synclog_buf_list.push_back(buf); + } + } + for (int device = 0; device < device_count; device++) { + uint32_t* buf = synclog_buf_list.at(device); + if (cudaSetDevice(device) != cudaSuccess || + cudaMemset(buf, 0, synclog_cap * sizeof(uint32_t)) != cudaSuccess || + cudaMemcpyToSymbol(synclog_buf, &buf, sizeof(buf)) != cudaSuccess) { + fail(); + } + } + if (cudaSetDevice(orig_device) != cudaSuccess) { + fail(); + } + #endif + #endif // defined(CUTLASS_ENABLE_SYNCLOG) +} + +CUTLASS_DEVICE +void synclog_emit_syncthreads(uint32_t line) { + #if defined(CUTLASS_ENABLE_SYNCLOG) + if constexpr (!synclog_enable_syncthreads) return; + if (!synclog_condition_emit()) return; + uint32_t* to = synclog_alloc(synclog_length_syncthreads); + if (to == nullptr) return; + synclog_emit_prefix(to, synclog_header_syncthreads, line); + #else + CUTLASS_UNUSED(line); + #endif // defined(CUTLASS_ENABLE_SYNCLOG) +} + +CUTLASS_DEVICE +void synclog_emit_syncwarp(uint32_t line) { + #if defined(CUTLASS_ENABLE_SYNCLOG) + if constexpr (!synclog_enable_syncwarp) return; + if (!synclog_condition_emit()) return; + uint32_t* to = synclog_alloc(synclog_length_syncwarp); + if (to == nullptr) return; + synclog_emit_prefix(to, synclog_header_syncwarp, line); + #else + CUTLASS_UNUSED(line); + #endif // defined(CUTLASS_ENABLE_SYNCLOG) +} + +CUTLASS_DEVICE +void synclog_emit_named_barrier_arrive_and_wait( + uint32_t line, + uint32_t num_threads, + uint32_t barrier_id) { + #if defined(CUTLASS_ENABLE_SYNCLOG) + if constexpr (!synclog_enable_named_barrier_arrive_and_wait) return; + if (!synclog_condition_emit()) return; + uint32_t* to = synclog_alloc(synclog_length_named_barrier_arrive_and_wait); + if (to == nullptr) return; + synclog_emit_prefix(to, synclog_header_named_barrier_arrive_and_wait, line); + to[synclog_length_prefix + 0] = num_threads; + to[synclog_length_prefix + 1] = barrier_id; + #else + CUTLASS_UNUSED(line); + CUTLASS_UNUSED(num_threads); + CUTLASS_UNUSED(barrier_id); + #endif // defined(CUTLASS_ENABLE_SYNCLOG) +} + +CUTLASS_DEVICE +void synclog_emit_named_barrier_arrive( + uint32_t line, + uint32_t num_threads, + uint32_t barrier_id) { + #if defined(CUTLASS_ENABLE_SYNCLOG) + if constexpr (!synclog_enable_named_barrier_arrive) return; + if (!synclog_condition_emit()) return; + uint32_t* to = synclog_alloc(synclog_length_named_barrier_arrive); + if (to == nullptr) return; + synclog_emit_prefix(to, synclog_header_named_barrier_arrive, line); + to[synclog_length_prefix + 0] = num_threads; + to[synclog_length_prefix + 1] = barrier_id; + #else + CUTLASS_UNUSED(line); + CUTLASS_UNUSED(num_threads); + CUTLASS_UNUSED(barrier_id); + #endif // defined(CUTLASS_ENABLE_SYNCLOG) +} + +CUTLASS_DEVICE +void synclog_emit_cluster_barrier_init( + uint32_t line, + uint32_t smem_addr, + uint32_t arrive_count) { + #if defined(CUTLASS_ENABLE_SYNCLOG) + if constexpr (!synclog_enable_cluster_barrier_init) return; + if 
(!synclog_condition_emit()) return; + uint32_t* to = synclog_alloc(synclog_length_cluster_barrier_init); + if (to == nullptr) return; + synclog_emit_prefix(to, synclog_header_cluster_barrier_init, line); + to[synclog_length_prefix + 0] = smem_addr; + to[synclog_length_prefix + 1] = arrive_count; + #else + CUTLASS_UNUSED(line); + CUTLASS_UNUSED(smem_addr); + CUTLASS_UNUSED(arrive_count); + #endif // defined(CUTLASS_ENABLE_SYNCLOG) +} + +CUTLASS_DEVICE +void synclog_emit_cluster_barrier_wait( + uint32_t line, + uint32_t smem_addr, + uint32_t phase) { + #if defined(CUTLASS_ENABLE_SYNCLOG) + if constexpr (!synclog_enable_cluster_barrier_wait) return; + if (!synclog_condition_emit()) return; + uint64_t bits = synclog_mbarrier_bits(smem_addr); + uint32_t* to = synclog_alloc(synclog_length_cluster_barrier_wait); + if (to == nullptr) return; + synclog_emit_prefix(to, synclog_header_cluster_barrier_wait, line); + to[synclog_length_prefix + 0] = smem_addr; + to[synclog_length_prefix + 1] = phase; + to[synclog_length_prefix + 2] = bits; + to[synclog_length_prefix + 3] = bits >> 32; + #else + CUTLASS_UNUSED(line); + CUTLASS_UNUSED(smem_addr); + CUTLASS_UNUSED(phase); + #endif // defined(CUTLASS_ENABLE_SYNCLOG) +} + +CUTLASS_DEVICE +void synclog_emit_cluster_barrier_test_wait( + uint32_t line, + uint32_t smem_addr, + uint32_t phase, + uint32_t pred) { + #if defined(CUTLASS_ENABLE_SYNCLOG) + if constexpr (!synclog_enable_cluster_barrier_test_wait) return; + if (!synclog_condition_emit()) return; + uint64_t bits = synclog_mbarrier_bits(smem_addr); + uint32_t* to = synclog_alloc(synclog_length_cluster_barrier_test_wait); + if (to == nullptr) return; + synclog_emit_prefix(to, synclog_header_cluster_barrier_test_wait, line); + to[synclog_length_prefix + 0] = smem_addr; + to[synclog_length_prefix + 1] = phase; + to[synclog_length_prefix + 2] = pred; + to[synclog_length_prefix + 3] = bits; + to[synclog_length_prefix + 4] = bits >> 32; + #else + CUTLASS_UNUSED(line); + CUTLASS_UNUSED(smem_addr); + CUTLASS_UNUSED(phase); + CUTLASS_UNUSED(pred); + #endif // defined(CUTLASS_ENABLE_SYNCLOG) +} + +CUTLASS_DEVICE +void synclog_emit_cluster_barrier_try_wait( + uint32_t line, + uint32_t smem_addr, + uint32_t phase) { + #if defined(CUTLASS_ENABLE_SYNCLOG) + if constexpr (!synclog_enable_cluster_barrier_try_wait) return; + if (!synclog_condition_emit()) return; + uint64_t bits = synclog_mbarrier_bits(smem_addr); + uint32_t* to = synclog_alloc(synclog_length_cluster_barrier_try_wait); + if (to == nullptr) return; + synclog_emit_prefix(to, synclog_header_cluster_barrier_try_wait, line); + to[synclog_length_prefix + 0] = smem_addr; + to[synclog_length_prefix + 1] = phase; + to[synclog_length_prefix + 2] = bits; + to[synclog_length_prefix + 3] = bits >> 32; + #else + CUTLASS_UNUSED(line); + CUTLASS_UNUSED(smem_addr); + CUTLASS_UNUSED(phase); + #endif // defined(CUTLASS_ENABLE_SYNCLOG) +} + +CUTLASS_DEVICE +void synclog_emit_cluster_barrier_arrive_cluster( + uint32_t line, + uint32_t smem_addr, + uint32_t cta_id, + uint32_t pred) { + #if defined(CUTLASS_ENABLE_SYNCLOG) + if constexpr (!synclog_enable_cluster_barrier_arrive_cluster) return; + if (!synclog_condition_emit()) return; + uint64_t bits = synclog_mbarrier_bits(smem_addr); + uint32_t* to = synclog_alloc(synclog_length_cluster_barrier_arrive_cluster); + if (to == nullptr) return; + synclog_emit_prefix(to, synclog_header_cluster_barrier_arrive_cluster, line); + to[synclog_length_prefix + 0] = smem_addr; + to[synclog_length_prefix + 1] = cta_id; + 
to[synclog_length_prefix + 2] = pred; + to[synclog_length_prefix + 3] = bits; + to[synclog_length_prefix + 4] = bits >> 32; + #else + CUTLASS_UNUSED(line); + CUTLASS_UNUSED(smem_addr); + CUTLASS_UNUSED(cta_id); + CUTLASS_UNUSED(pred); + #endif // defined(CUTLASS_ENABLE_SYNCLOG) +} + +CUTLASS_DEVICE +void synclog_emit_cluster_barrier_arrive( + uint32_t line, + uint32_t smem_addr) { + #if defined(CUTLASS_ENABLE_SYNCLOG) + if constexpr (!synclog_enable_cluster_barrier_arrive) return; + if (!synclog_condition_emit()) return; + uint64_t bits = synclog_mbarrier_bits(smem_addr); + uint32_t* to = synclog_alloc(synclog_length_cluster_barrier_arrive); + if (to == nullptr) return; + synclog_emit_prefix(to, synclog_header_cluster_barrier_arrive, line); + to[synclog_length_prefix + 0] = smem_addr; + to[synclog_length_prefix + 1] = bits; + to[synclog_length_prefix + 2] = bits >> 32; + #else + CUTLASS_UNUSED(line); + CUTLASS_UNUSED(smem_addr); + #endif // defined(CUTLASS_ENABLE_SYNCLOG) +} + +CUTLASS_DEVICE +void synclog_emit_cluster_barrier_invalidate( + uint32_t line, + uint32_t smem_addr) { + #if defined(CUTLASS_ENABLE_SYNCLOG) + if constexpr (!synclog_enable_cluster_barrier_invalidate) return; + if (!synclog_condition_emit()) return; + uint64_t bits = synclog_mbarrier_bits(smem_addr); + uint32_t* to = synclog_alloc(synclog_length_cluster_barrier_invalidate); + if (to == nullptr) return; + synclog_emit_prefix(to, synclog_header_cluster_barrier_invalidate, line); + to[synclog_length_prefix + 0] = smem_addr; + to[synclog_length_prefix + 1] = bits; + to[synclog_length_prefix + 2] = bits >> 32; + #else + CUTLASS_UNUSED(line); + CUTLASS_UNUSED(smem_addr); + #endif // defined(CUTLASS_ENABLE_SYNCLOG) +} + +CUTLASS_DEVICE +void synclog_emit_cluster_transaction_barrier_arrive_and_expect_tx( + uint32_t line, + uint32_t smem_addr, + uint32_t transaction_bytes) { + #if defined(CUTLASS_ENABLE_SYNCLOG) + if constexpr (!synclog_enable_cluster_transaction_barrier_arrive_and_expect_tx) return; + if (!synclog_condition_emit()) return; + uint64_t bits = synclog_mbarrier_bits(smem_addr); + uint32_t* to = synclog_alloc(synclog_length_cluster_transaction_barrier_arrive_and_expect_tx); + if (to == nullptr) return; + synclog_emit_prefix(to, synclog_header_cluster_transaction_barrier_arrive_and_expect_tx, line); + to[synclog_length_prefix + 0] = smem_addr; + to[synclog_length_prefix + 1] = transaction_bytes; + to[synclog_length_prefix + 2] = bits; + to[synclog_length_prefix + 3] = bits >> 32; + #else + CUTLASS_UNUSED(line); + CUTLASS_UNUSED(smem_addr); + CUTLASS_UNUSED(transaction_bytes); + #endif // defined(CUTLASS_ENABLE_SYNCLOG) +} + +CUTLASS_DEVICE +void synclog_emit_cluster_transaction_barrier_arrive_and_expect_tx_cluster( + uint32_t line, + uint32_t smem_addr, + uint32_t transaction_bytes, + uint32_t cta_id, + uint32_t pred) { + #if defined(CUTLASS_ENABLE_SYNCLOG) + if constexpr (!synclog_enable_cluster_transaction_barrier_arrive_and_expect_tx_cluster) return; + if (!synclog_condition_emit()) return; + uint64_t bits = synclog_mbarrier_bits(smem_addr); + uint32_t* to = synclog_alloc(synclog_length_cluster_transaction_barrier_arrive_and_expect_tx_cluster); + if (to == nullptr) return; + synclog_emit_prefix(to, synclog_header_cluster_transaction_barrier_arrive_and_expect_tx_cluster, line); + to[synclog_length_prefix + 0] = smem_addr; + to[synclog_length_prefix + 1] = transaction_bytes; + to[synclog_length_prefix + 2] = cta_id; + to[synclog_length_prefix + 3] = pred; + to[synclog_length_prefix + 4] = bits; + 
to[synclog_length_prefix + 5] = bits >> 32; + #else + CUTLASS_UNUSED(line); + CUTLASS_UNUSED(smem_addr); + CUTLASS_UNUSED(transaction_bytes); + CUTLASS_UNUSED(cta_id); + CUTLASS_UNUSED(pred); + #endif // defined(CUTLASS_ENABLE_SYNCLOG) +} + +CUTLASS_DEVICE +void synclog_emit_cluster_transaction_barrier_expect_transaction( + uint32_t line, + uint32_t smem_addr, + uint32_t transaction_bytes) { + #if defined(CUTLASS_ENABLE_SYNCLOG) + if constexpr (!synclog_enable_cluster_transaction_barrier_expect_transaction) return; + if (!synclog_condition_emit()) return; + uint64_t bits = synclog_mbarrier_bits(smem_addr); + uint32_t* to = synclog_alloc(synclog_length_cluster_transaction_barrier_expect_transaction); + if (to == nullptr) return; + synclog_emit_prefix(to, synclog_header_cluster_transaction_barrier_expect_transaction, line); + to[synclog_length_prefix + 0] = smem_addr; + to[synclog_length_prefix + 1] = transaction_bytes; + to[synclog_length_prefix + 2] = bits; + to[synclog_length_prefix + 3] = bits >> 32; + #else + CUTLASS_UNUSED(line); + CUTLASS_UNUSED(smem_addr); + CUTLASS_UNUSED(transaction_bytes); + #endif // defined(CUTLASS_ENABLE_SYNCLOG) +} + +CUTLASS_DEVICE +void synclog_emit_cluster_transaction_barrier_complete_transaction( + uint32_t line, + uint32_t smem_addr, + uint32_t dst_cta_id, + uint32_t transaction_bytes, + uint32_t pred) { + #if defined(CUTLASS_ENABLE_SYNCLOG) + if constexpr (!synclog_enable_cluster_transaction_barrier_complete_transaction) return; + if (!synclog_condition_emit()) return; + uint64_t bits = synclog_mbarrier_bits(smem_addr); + uint32_t* to = synclog_alloc(synclog_length_cluster_transaction_barrier_complete_transaction); + if (to == nullptr) return; + synclog_emit_prefix(to, synclog_header_cluster_transaction_barrier_complete_transaction, line); + to[synclog_length_prefix + 0] = smem_addr; + to[synclog_length_prefix + 1] = dst_cta_id; + to[synclog_length_prefix + 2] = transaction_bytes; + to[synclog_length_prefix + 3] = pred; + to[synclog_length_prefix + 4] = bits; + to[synclog_length_prefix + 5] = bits >> 32; + #else + CUTLASS_UNUSED(line); + CUTLASS_UNUSED(smem_addr); + CUTLASS_UNUSED(dst_cta_id); + CUTLASS_UNUSED(transaction_bytes); + CUTLASS_UNUSED(pred); + #endif // defined(CUTLASS_ENABLE_SYNCLOG) +} + +CUTLASS_DEVICE +void synclog_emit_fence_barrier_init(uint32_t line) { + #if defined(CUTLASS_ENABLE_SYNCLOG) + if constexpr (!synclog_enable_fence_barrier_init) return; + if (!synclog_condition_emit()) return; + uint32_t* to = synclog_alloc(synclog_length_fence_barrier_init); + if (to == nullptr) return; + synclog_emit_prefix(to, synclog_header_fence_barrier_init, line); + #else + CUTLASS_UNUSED(line); + #endif // defined(CUTLASS_ENABLE_SYNCLOG) +} + +CUTLASS_DEVICE +void synclog_emit_fence_view_async_shared(uint32_t line) { + #if defined(CUTLASS_ENABLE_SYNCLOG) + if constexpr (!synclog_enable_fence_view_async_shared) return; + if (!synclog_condition_emit()) return; + uint32_t* to = synclog_alloc(synclog_length_fence_view_async_shared); + if (to == nullptr) return; + synclog_emit_prefix(to, synclog_header_fence_view_async_shared, line); + #else + CUTLASS_UNUSED(line); + #endif // defined(CUTLASS_ENABLE_SYNCLOG) +} + +CUTLASS_DEVICE +void synclog_emit_cp_async_wait( + uint32_t line, + uint32_t n) { + #if defined(CUTLASS_ENABLE_SYNCLOG) + if constexpr (!synclog_enable_cp_async_wait) return; + if (!synclog_condition_emit()) return; + uint32_t* to = synclog_alloc(synclog_length_cp_async_wait); + if (to == nullptr) return; + synclog_emit_prefix(to, 
synclog_header_cp_async_wait, line); + to[synclog_length_prefix + 0] = n; + #else + CUTLASS_UNUSED(line); + CUTLASS_UNUSED(n); + #endif // defined(CUTLASS_ENABLE_SYNCLOG) +} + +CUTLASS_DEVICE +void synclog_emit_cp_async_wait_all(uint32_t line) { + #if defined(CUTLASS_ENABLE_SYNCLOG) + if constexpr (!synclog_enable_cp_async_wait_all) return; + if (!synclog_condition_emit()) return; + uint32_t* to = synclog_alloc(synclog_length_cp_async_wait_all); + if (to == nullptr) return; + synclog_emit_prefix(to, synclog_header_cp_async_wait_all, line); + #else + CUTLASS_UNUSED(line); + #endif // defined(CUTLASS_ENABLE_SYNCLOG) +} + +CUTLASS_DEVICE +void synclog_emit_cp_async_fence(uint32_t line) { + #if defined(CUTLASS_ENABLE_SYNCLOG) + if constexpr (!synclog_enable_cp_async_fence) return; + if (!synclog_condition_emit()) return; + uint32_t* to = synclog_alloc(synclog_length_cp_async_fence); + if (to == nullptr) return; + synclog_emit_prefix(to, synclog_header_cp_async_fence, line); + #else + CUTLASS_UNUSED(line); + #endif // defined(CUTLASS_ENABLE_SYNCLOG) +} + +CUTLASS_DEVICE +void synclog_emit_cp_async_nan( + uint32_t line, + uint32_t smem_addr, + const void* gmem_ptr, + uint32_t pred) { + #if defined(CUTLASS_ENABLE_SYNCLOG) + if constexpr (!synclog_enable_cp_async_nan) return; + if (!synclog_condition_emit()) return; + uint32_t* to = synclog_alloc(synclog_length_cp_async_nan); + if (to == nullptr) return; + synclog_emit_prefix(to, synclog_header_cp_async_nan, line); + to[synclog_length_prefix + 0] = smem_addr; + to[synclog_length_prefix + 1] = (uint32_t)((uint64_t)gmem_ptr); + to[synclog_length_prefix + 2] = (uint32_t)((uint64_t)gmem_ptr >> 32); + to[synclog_length_prefix + 3] = pred; + #else + CUTLASS_UNUSED(line); + CUTLASS_UNUSED(smem_addr); + CUTLASS_UNUSED(gmem_ptr); + CUTLASS_UNUSED(pred); + #endif // defined(CUTLASS_ENABLE_SYNCLOG) +} + +CUTLASS_DEVICE +void synclog_emit_cp_async_zfill( + uint32_t line, + uint32_t smem_addr, + const void* gmem_ptr, + uint32_t pred, + uint32_t size) { + #if defined(CUTLASS_ENABLE_SYNCLOG) + if constexpr (!synclog_enable_cp_async_zfill) return; + if (!synclog_condition_emit()) return; + uint32_t* to = synclog_alloc(synclog_length_cp_async_zfill); + if (to == nullptr) return; + synclog_emit_prefix(to, synclog_header_cp_async_zfill, line); + to[synclog_length_prefix + 0] = smem_addr; + to[synclog_length_prefix + 1] = (uint32_t)((uint64_t)gmem_ptr); + to[synclog_length_prefix + 2] = (uint32_t)((uint64_t)gmem_ptr >> 32); + to[synclog_length_prefix + 3] = pred; + to[synclog_length_prefix + 4] = size; + #else + CUTLASS_UNUSED(line); + CUTLASS_UNUSED(smem_addr); + CUTLASS_UNUSED(gmem_ptr); + CUTLASS_UNUSED(pred); + CUTLASS_UNUSED(size); + #endif // defined(CUTLASS_ENABLE_SYNCLOG) +} + +CUTLASS_DEVICE +void synclog_emit_cp_async( + uint32_t line, + uint32_t smem_addr, + const void* gmem_ptr, + uint32_t pred, + uint32_t size) { + #if defined(CUTLASS_ENABLE_SYNCLOG) + if constexpr (!synclog_enable_cp_async) return; + if (!synclog_condition_emit()) return; + uint32_t* to = synclog_alloc(synclog_length_cp_async); + if (to == nullptr) return; + synclog_emit_prefix(to, synclog_header_cp_async, line); + to[synclog_length_prefix + 0] = smem_addr; + to[synclog_length_prefix + 1] = (uint32_t)((uint64_t)gmem_ptr); + to[synclog_length_prefix + 2] = (uint32_t)((uint64_t)gmem_ptr >> 32); + to[synclog_length_prefix + 3] = pred; + to[synclog_length_prefix + 4] = size; + #else + CUTLASS_UNUSED(line); + CUTLASS_UNUSED(smem_addr); + CUTLASS_UNUSED(gmem_ptr); + CUTLASS_UNUSED(pred); + 
CUTLASS_UNUSED(size); + #endif // defined(CUTLASS_ENABLE_SYNCLOG) +} + +CUTLASS_DEVICE +void synclog_emit_tma_load( + uint32_t line, + uint64_t gmem_int_desc, + uint32_t smem_int_mbar, + uint32_t smem_int_ptr) { + #if defined(CUTLASS_ENABLE_SYNCLOG) + if constexpr (!synclog_enable_tma_load) return; + if (!synclog_condition_emit()) return; + uint32_t* to = synclog_alloc(synclog_length_tma_load); + if (to == nullptr) return; + synclog_emit_prefix(to, synclog_header_tma_load, line); + to[synclog_length_prefix + 0] = (uint32_t)((uint64_t)gmem_int_desc); + to[synclog_length_prefix + 1] = (uint32_t)((uint64_t)gmem_int_desc >> 32); + to[synclog_length_prefix + 2] = smem_int_mbar; + to[synclog_length_prefix + 3] = smem_int_ptr; + #else + CUTLASS_UNUSED(line); + CUTLASS_UNUSED(gmem_int_desc); + CUTLASS_UNUSED(smem_int_mbar); + CUTLASS_UNUSED(smem_int_ptr); + #endif // defined(CUTLASS_ENABLE_SYNCLOG) +} + +CUTLASS_DEVICE +void synclog_emit_tma_store( + uint32_t line, + uint64_t gmem_int_desc, + uint32_t smem_int_ptr) { + #if defined(CUTLASS_ENABLE_SYNCLOG) + if constexpr (!synclog_enable_tma_store) return; + if (!synclog_condition_emit()) return; + uint32_t* to = synclog_alloc(synclog_length_tma_store); + if (to == nullptr) return; + synclog_emit_prefix(to, synclog_header_tma_store, line); + to[synclog_length_prefix + 0] = (uint32_t)((uint64_t)gmem_int_desc); + to[synclog_length_prefix + 1] = (uint32_t)((uint64_t)gmem_int_desc >> 32); + to[synclog_length_prefix + 2] = smem_int_ptr; + #else + CUTLASS_UNUSED(line); + CUTLASS_UNUSED(gmem_int_desc); + CUTLASS_UNUSED(smem_int_ptr); + #endif // defined(CUTLASS_ENABLE_SYNCLOG) +} + +CUTLASS_DEVICE +void synclog_emit_tma_store_arrive(uint32_t line) { + #if defined(CUTLASS_ENABLE_SYNCLOG) + if constexpr (!synclog_enable_tma_store_arrive) return; + if (!synclog_condition_emit()) return; + uint32_t* to = synclog_alloc(synclog_length_tma_store_arrive); + if (to == nullptr) return; + synclog_emit_prefix(to, synclog_header_tma_store_arrive, line); + #else + CUTLASS_UNUSED(line); + #endif // defined(CUTLASS_ENABLE_SYNCLOG) +} + +CUTLASS_DEVICE +void synclog_emit_tma_store_wait( + uint32_t line, + uint32_t count) { + #if defined(CUTLASS_ENABLE_SYNCLOG) + if constexpr (!synclog_enable_tma_store_wait) return; + if (!synclog_condition_emit()) return; + uint32_t* to = synclog_alloc(synclog_length_tma_store_wait); + if (to == nullptr) return; + synclog_emit_prefix(to, synclog_header_tma_store_wait, line); + to[synclog_length_prefix + 0] = count; + #else + CUTLASS_UNUSED(line); + CUTLASS_UNUSED(count); + #endif // defined(CUTLASS_ENABLE_SYNCLOG) +} + +CUTLASS_DEVICE +void synclog_emit_warpgroup_arrive( + uint32_t line) { + #if defined(CUTLASS_ENABLE_SYNCLOG) + if constexpr (!synclog_enable_warpgroup_arrive) return; + if (!synclog_condition_emit()) return; + uint32_t* to = synclog_alloc(synclog_length_warpgroup_arrive); + if (to == nullptr) return; + synclog_emit_prefix(to, synclog_header_warpgroup_arrive, line); + #else + CUTLASS_UNUSED(line); + #endif // defined(CUTLASS_ENABLE_SYNCLOG) +} + +CUTLASS_DEVICE +void synclog_emit_warpgroup_wait( + uint32_t line, + uint32_t n) { + #if defined(CUTLASS_ENABLE_SYNCLOG) + if constexpr (!synclog_enable_warpgroup_wait) return; + if (!synclog_condition_emit()) return; + uint32_t* to = synclog_alloc(synclog_length_warpgroup_wait); + if (to == nullptr) return; + synclog_emit_prefix(to, synclog_header_warpgroup_wait, line); + to[synclog_length_prefix + 0] = n; + #else + CUTLASS_UNUSED(line); + CUTLASS_UNUSED(n); + #endif // 
defined(CUTLASS_ENABLE_SYNCLOG) +} + +CUTLASS_DEVICE +void synclog_emit_warpgroup_commit_batch( + uint32_t line) { + #if defined(CUTLASS_ENABLE_SYNCLOG) + if constexpr (!synclog_enable_warpgroup_commit_batch) return; + if (!synclog_condition_emit()) return; + uint32_t* to = synclog_alloc(synclog_length_warpgroup_commit_batch); + if (to == nullptr) return; + synclog_emit_prefix(to, synclog_header_warpgroup_commit_batch, line); + #else + CUTLASS_UNUSED(line); + #endif // defined(CUTLASS_ENABLE_SYNCLOG) +} + +CUTLASS_DEVICE +void synclog_emit_wgmma_reg_smem( + uint32_t line, + uint64_t desc_b) { + #if defined(CUTLASS_ENABLE_SYNCLOG) + if constexpr (!synclog_enable_wgmma_reg_smem) return; + if (!synclog_condition_emit()) return; + uint32_t* to = synclog_alloc(synclog_length_wgmma_reg_smem); + if (to == nullptr) return; + synclog_emit_prefix(to, synclog_header_wgmma_reg_smem, line); + to[synclog_length_prefix + 0] = desc_b; + to[synclog_length_prefix + 1] = desc_b >> 32; + #else + CUTLASS_UNUSED(line); + CUTLASS_UNUSED(desc_b); + #endif // defined(CUTLASS_ENABLE_SYNCLOG) +} + +CUTLASS_DEVICE +void synclog_emit_wgmma_smem_smem( + uint32_t line, + uint64_t desc_a, + uint64_t desc_b) { + #if defined(CUTLASS_ENABLE_SYNCLOG) + if constexpr (!synclog_enable_wgmma_smem_smem) return; + if (!synclog_condition_emit()) return; + uint32_t* to = synclog_alloc(synclog_length_wgmma_smem_smem); + if (to == nullptr) return; + synclog_emit_prefix(to, synclog_header_wgmma_smem_smem, line); + to[synclog_length_prefix + 0] = desc_a; + to[synclog_length_prefix + 1] = desc_a >> 32; + to[synclog_length_prefix + 2] = desc_b; + to[synclog_length_prefix + 3] = desc_b >> 32; + #else + CUTLASS_UNUSED(line); + CUTLASS_UNUSED(desc_a); + CUTLASS_UNUSED(desc_b); + #endif // defined(CUTLASS_ENABLE_SYNCLOG) +} + +CUTLASS_DEVICE +void synclog_emit_cpasync_barrier_arrive( + uint32_t line, + uint32_t smem_addr) { + #if defined(CUTLASS_ENABLE_SYNCLOG) + if constexpr (!synclog_enable_cpasync_barrier_arrive) return; + if (!synclog_condition_emit()) return; + uint64_t bits = synclog_mbarrier_bits(smem_addr); + uint32_t* to = synclog_alloc(synclog_length_cpasync_barrier_arrive); + if (to == nullptr) return; + synclog_emit_prefix(to, synclog_header_cpasync_barrier_arrive, line); + to[synclog_length_prefix + 0] = smem_addr; + to[synclog_length_prefix + 1] = bits; + to[synclog_length_prefix + 2] = bits >> 32; + #else + CUTLASS_UNUSED(line); + CUTLASS_UNUSED(smem_addr); + #endif // defined(CUTLASS_ENABLE_SYNCLOG) +} + +#if !defined(CUTLASS_ENABLE_SYNCLOG) +CUTLASS_DEVICE +#elif defined(__NVCC__) || (defined(__clang__) && defined(__CUDA__)) +static __attribute__((__noinline__)) __device__ +#else +static __attribute__((__noinline__)) +#endif +void synclog_print() { + #if defined(CUTLASS_ENABLE_SYNCLOG) + #if defined(__NVCC__) || (defined(__clang__) && defined(__CUDA__)) + if (synclog_buf == nullptr || !synclog_condition_print()) { + return; + } + printf("synclog start\n"); + for (uint32_t at = 1; at < synclog_cap; ) { + uint32_t header = synclog_buf[at]; + if (header == synclog_header_none) { + break; + } + printf("synclog at %u: ", at); + if constexpr (synclog_enable_syncthreads) { + if (header == synclog_header_syncthreads) { + synclog_print_prefix("syncthreads", at); + at += synclog_length_syncthreads; + printf("\n"); + continue; + } + } + if constexpr (synclog_enable_syncwarp) { + if (header == synclog_header_syncwarp) { + synclog_print_prefix("syncwarp", at); + at += synclog_length_syncwarp; + printf("\n"); + continue; + } + } + if 
constexpr (synclog_enable_named_barrier_arrive_and_wait) { + if (header == synclog_header_named_barrier_arrive_and_wait) { + synclog_print_prefix("named_barrier_arrive_and_wait", at); + at += synclog_length_named_barrier_arrive_and_wait; + printf("num_threads=%u barrier_id=%u\n", synclog_buf[at-2], synclog_buf[at-1]); + continue; + } + } + if constexpr (synclog_enable_named_barrier_arrive) { + if (header == synclog_header_named_barrier_arrive) { + synclog_print_prefix("named_barrier_arrive", at); + at += synclog_length_named_barrier_arrive; + printf("num_threads=%u barrier_id=%u\n", synclog_buf[at-2], synclog_buf[at-1]); + continue; + } + } + if constexpr (synclog_enable_cluster_barrier_init) { + if (header == synclog_header_cluster_barrier_init) { + synclog_print_prefix("cluster_barrier_init", at); + at += synclog_length_cluster_barrier_init; + printf("smem_addr=%u arrive_count=%u\n", synclog_buf[at-2], synclog_buf[at-1]); + continue; + } + } + if constexpr (synclog_enable_cluster_barrier_wait) { + if (header == synclog_header_cluster_barrier_wait) { + synclog_print_prefix("cluster_barrier_wait", at); + at += synclog_length_cluster_barrier_wait; + printf("smem_addr=%u phase=%u", synclog_buf[at-4], synclog_buf[at-3]); + continue; + } + } + if constexpr (synclog_enable_cluster_barrier_test_wait) { + if (header == synclog_header_cluster_barrier_test_wait) { + synclog_print_prefix("cluster_barrier_test_wait", at); + at += synclog_length_cluster_barrier_test_wait; + printf("smem_addr=%u phase=%u pred=%u", synclog_buf[at-5], synclog_buf[at-4], synclog_buf[at-3]); + continue; + } + } + if constexpr (synclog_enable_cluster_barrier_try_wait) { + if (header == synclog_header_cluster_barrier_try_wait) { + synclog_print_prefix("cluster_barrier_try_wait", at); + at += synclog_length_cluster_barrier_try_wait; + printf("smem_addr=%u phase=%u", synclog_buf[at-4], synclog_buf[at-3]); + continue; + } + } + if constexpr (synclog_enable_cluster_barrier_arrive_cluster) { + if (header == synclog_header_cluster_barrier_arrive_cluster) { + synclog_print_prefix("cluster_barrier_arrive_cluster", at); + at += synclog_length_cluster_barrier_arrive_cluster; + printf("smem_addr=%u cta_id=%u pred=%u", synclog_buf[at-5], synclog_buf[at-4], synclog_buf[at-3]); + continue; + } + } + if constexpr (synclog_enable_cluster_barrier_arrive) { + if (header == synclog_header_cluster_barrier_arrive) { + synclog_print_prefix("cluster_barrier_arrive", at); + at += synclog_length_cluster_barrier_arrive; + printf("smem_addr=%u", synclog_buf[at-3]); + continue; + } + } + if constexpr (synclog_enable_cluster_barrier_invalidate) { + if (header == synclog_header_cluster_barrier_invalidate) { + synclog_print_prefix("cluster_barrier_invalidate", at); + at += synclog_length_cluster_barrier_invalidate; + printf("smem_addr=%u", synclog_buf[at-3]); + continue; + } + } + if constexpr (synclog_enable_cluster_transaction_barrier_arrive_and_expect_tx) { + if (header == synclog_header_cluster_transaction_barrier_arrive_and_expect_tx) { + synclog_print_prefix("cluster_transaction_barrier_arrive_and_expect_tx", at); + at += synclog_length_cluster_transaction_barrier_arrive_and_expect_tx; + printf("smem_addr=%u transaction_bytes=%u", synclog_buf[at-4], synclog_buf[at-3]); + continue; + } + } + if constexpr (synclog_enable_cluster_transaction_barrier_arrive_and_expect_tx_cluster) { + if (header == synclog_header_cluster_transaction_barrier_arrive_and_expect_tx_cluster) { + synclog_print_prefix("cluster_transaction_barrier_arrive_and_expect_tx_cluster", 
at); + at += synclog_length_cluster_transaction_barrier_arrive_and_expect_tx_cluster; + printf("smem_addr=%u transaction_bytes=%u cta_id=%u pred=%u", synclog_buf[at-6], synclog_buf[at-5], synclog_buf[at-4], synclog_buf[at-3]); + continue; + } + } + if constexpr (synclog_enable_cluster_transaction_barrier_expect_transaction) { + if (header == synclog_header_cluster_transaction_barrier_expect_transaction) { + synclog_print_prefix("cluster_transaction_barrier_expect_transaction", at); + at += synclog_length_cluster_transaction_barrier_expect_transaction; + printf("smem_addr=%u transaction_bytes=%u", synclog_buf[at-4], synclog_buf[at-3]); + continue; + } + } + if constexpr (synclog_enable_cluster_transaction_barrier_complete_transaction) { + if (header == synclog_header_cluster_transaction_barrier_complete_transaction) { + synclog_print_prefix("cluster_transaction_barrier_complete_transaction", at); + at += synclog_length_cluster_transaction_barrier_complete_transaction; + printf("smem_addr=%u dst_cta_id=%u transaction_bytes=%u pred=%u", synclog_buf[at-6], synclog_buf[at-5], synclog_buf[at-4], synclog_buf[at-3]); + continue; + } + } + if constexpr (synclog_enable_fence_barrier_init) { + if (header == synclog_header_fence_barrier_init) { + synclog_print_prefix("fence_barrier_init", at); + at += synclog_length_fence_barrier_init; + printf("\n"); + continue; + } + } + if constexpr (synclog_enable_fence_view_async_shared) { + if (header == synclog_header_fence_view_async_shared) { + synclog_print_prefix("fence_view_async_shared", at); + at += synclog_length_fence_view_async_shared; + printf("\n"); + continue; + } + } + if constexpr (synclog_enable_cp_async_wait) { + if (header == synclog_header_cp_async_wait) { + synclog_print_prefix("cp_async_wait", at); + at += synclog_length_cp_async_wait; + printf("n=%u\n", synclog_buf[at-1]); + continue; + } + } + if constexpr (synclog_enable_cp_async_wait_all) { + if (header == synclog_header_cp_async_wait_all) { + synclog_print_prefix("cp_async_wait_all", at); + at += synclog_length_cp_async_wait_all; + printf("\n"); + continue; + } + } + if constexpr (synclog_enable_cp_async_fence) { + if (header == synclog_header_cp_async_fence) { + synclog_print_prefix("cp_async_fence", at); + at += synclog_length_cp_async_fence; + printf("\n"); + continue; + } + } + if constexpr (synclog_enable_cp_async_nan) { + if (header == synclog_header_cp_async_nan) { + synclog_print_prefix("cp_async_nan", at); + at += synclog_length_cp_async_nan; + uint64_t gmem_addr = synclog_buf[at-3]; + gmem_addr += (uint64_t)synclog_buf[at-2] << 32; + printf("smem_addr=%u gmem_addr=%llu pred=%u\n", synclog_buf[at-4], gmem_addr, synclog_buf[at-1]); + continue; + } + } + if constexpr (synclog_enable_cp_async_zfill) { + if (header == synclog_header_cp_async_zfill) { + synclog_print_prefix("cp_async_zfill", at); + at += synclog_length_cp_async_zfill; + uint64_t gmem_addr = synclog_buf[at-4]; + gmem_addr += (uint64_t)synclog_buf[at-3] << 32; + printf("smem_addr=%u gmem_addr=%llu pred=%u size=%u\n", synclog_buf[at-5], gmem_addr, synclog_buf[at-2], synclog_buf[at-1]); + continue; + } + } + if constexpr (synclog_enable_cp_async) { + if (header == synclog_header_cp_async) { + synclog_print_prefix("cp_async", at); + at += synclog_length_cp_async; + uint64_t gmem_addr = synclog_buf[at-4]; + gmem_addr += (uint64_t)synclog_buf[at-3] << 32; + printf("smem_addr=%u gmem_addr=%llu pred=%u size=%u\n", synclog_buf[at-5], gmem_addr, synclog_buf[at-2], synclog_buf[at-1]); + continue; + } + } + if constexpr 
(synclog_enable_tma_load) { + if (header == synclog_header_tma_load) { + synclog_print_prefix("tma_load", at); + at += synclog_length_tma_load; + uint64_t gmem_int_desc = synclog_buf[at-4]; + gmem_int_desc += (uint64_t)synclog_buf[at-3] << 32; + printf("gmem_int_desc=%llu smem_int_mbar=%u smem_int_ptr=%u\n", gmem_int_desc, synclog_buf[at-2], synclog_buf[at-1]); + continue; + } + } + if constexpr (synclog_enable_tma_store) { + if (header == synclog_header_tma_store) { + synclog_print_prefix("tma_store", at); + at += synclog_length_tma_store; + uint64_t gmem_int_desc = synclog_buf[at-3]; + gmem_int_desc += (uint64_t)synclog_buf[at-2] << 32; + printf("gmem_int_desc=%llu smem_int_ptr=%u\n", gmem_int_desc, synclog_buf[at-1]); + continue; + } + } + if constexpr (synclog_enable_tma_store_arrive) { + if (header == synclog_header_tma_store_arrive) { + synclog_print_prefix("tma_store_arrive", at); + at += synclog_length_tma_store_arrive; + printf("\n"); + continue; + } + } + if constexpr (synclog_enable_tma_store_wait) { + if (header == synclog_header_tma_store_wait) { + synclog_print_prefix("tma_store_wait", at); + at += synclog_length_tma_store_wait; + printf("count=%u\n", synclog_buf[at-1]); + continue; + } + } + if constexpr (synclog_enable_warpgroup_arrive) { + if (header == synclog_header_warpgroup_arrive) { + synclog_print_prefix("warpgroup_arrive", at); + at += synclog_length_warpgroup_arrive; + printf("\n"); + continue; + } + } + if constexpr (synclog_enable_warpgroup_wait) { + if (header == synclog_header_warpgroup_wait) { + synclog_print_prefix("warpgroup_wait", at); + at += synclog_length_warpgroup_wait; + printf("n=%u\n", synclog_buf[at-1]); + continue; + } + } + if constexpr (synclog_enable_warpgroup_commit_batch) { + if (header == synclog_header_warpgroup_commit_batch) { + synclog_print_prefix("warpgroup_commit_batch", at); + at += synclog_length_warpgroup_commit_batch; + printf("\n"); + continue; + } + } + if constexpr (synclog_enable_wgmma_reg_smem) { + if (header == synclog_header_wgmma_reg_smem) { + synclog_print_prefix("wgmma_reg_smem", at); + at += synclog_length_wgmma_reg_smem; + synclog_print_wgmma_desc("desc_b", synclog_buf[at-2], synclog_buf[at-1], ""); + printf("\n"); + continue; + } + } + if constexpr (synclog_enable_wgmma_smem_smem) { + if (header == synclog_header_wgmma_smem_smem) { + synclog_print_prefix("wgmma_smem_smem", at); + at += synclog_length_wgmma_smem_smem; + synclog_print_wgmma_desc("desc_a", synclog_buf[at-4], synclog_buf[at-3], " "); + synclog_print_wgmma_desc("desc_b", synclog_buf[at-2], synclog_buf[at-1], ""); + printf("\n"); + continue; + } + } + if constexpr (synclog_enable_cpasync_barrier_arrive) { + if (header == synclog_header_cpasync_barrier_arrive) { + synclog_print_prefix("cpasync_barrier_arrive", at); + at += synclog_length_cpasync_barrier_arrive; + printf("smem_addr=%u", synclog_buf[at-3]); + continue; + } + } + asm volatile ("brkpt;\n" ::); + } + if (synclog_buf[0] >= synclog_cap) { + printf( + "synclog was truncated (exceeded capacity of %lu bytes)\n", + (synclog_cap - 1) * sizeof(uint32_t) + ); + } + printf("synclog end\n"); + #endif + #endif // defined(CUTLASS_ENABLE_SYNCLOG) +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTLASS_ENABLE_SYNCLOG) +#undef __syncthreads +#define __syncthreads() do {\ + cutlass::arch::synclog_emit_syncthreads(__LINE__);\ + __syncthreads();\ +} while (0) +#endif // defined(CUTLASS_ENABLE_SYNCLOG) + +#if defined(CUTLASS_ENABLE_SYNCLOG) 
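+// As with __syncthreads() above, the wrapper below logs the call site's
+// __LINE__ and then performs the actual synchronization. The inner
+// __syncwarp(__VA_ARGS__) call is not expanded recursively because the
+// preprocessor never re-expands a macro inside its own replacement list.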
+#undef __syncwarp +#define __syncwarp(...) do {\ + cutlass::arch::synclog_emit_syncwarp(__LINE__);\ + __syncwarp(__VA_ARGS__);\ +} while (0) +#endif // defined(CUTLASS_ENABLE_SYNCLOG) + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace arch +} // namespace cutlass diff --git a/include/cutlass/array.h b/include/cutlass/array.h index 499d45c724..62e9469497 100644 --- a/include/cutlass/array.h +++ b/include/cutlass/array.h @@ -37,6 +37,7 @@ #include "cutlass/cutlass.h" #include "cutlass/functional.h" #include "cutlass/numeric_types.h" +#include "cutlass/platform/platform.h" namespace cutlass { //////////////////////////////////////////////////////////////////////////////////////////////////// @@ -49,6 +50,23 @@ template < > struct Array; +namespace detail { + +template +struct is_Array : platform::false_type {}; + +template < + typename T, + int N, + bool RegisterSized +> +struct is_Array > : platform::true_type {}; + +template +constexpr bool is_Array_v = is_Array::value; + +} // namespace detail + //////////////////////////////////////////////////////////////////////////////////////////////////// /// Defines the size of an Array<> in bits @@ -803,111 +821,14 @@ struct reciprocal_approximate_ftz> { } }; -template -struct maximum, false> { - - CUTLASS_HOST_DEVICE - Array operator()(Array const &lhs, Array const &rhs) const { - - Array result; - maximum scalar_op; - - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < N; ++i) { - result[i] = scalar_op(lhs[i], rhs[i]); - } - - return result; - } - - CUTLASS_HOST_DEVICE - Array operator()(Array const &lhs, T const &scalar) const { - - Array result; - maximum scalar_op; - - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < N; ++i) { - result[i] = scalar_op(lhs[i], scalar); - } - - return result; - } - - CUTLASS_HOST_DEVICE - Array operator()(T const &scalar, Array const &rhs) const { - - Array result; - maximum scalar_op; - - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < N; ++i) { - result[i] = scalar_op(scalar, rhs[i]); - } - - return result; - } -}; - -template -struct maximum, true> { - - CUTLASS_HOST_DEVICE - Array operator()(Array const &lhs, Array const &rhs) const { - - Array result; - maximum scalar_op; - - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < N; ++i) { - result[i] = scalar_op(lhs[i], rhs[i]); - } - - return result; - } - - CUTLASS_HOST_DEVICE - Array operator()(Array const &lhs, T const &scalar) const { - - Array result; - maximum scalar_op; - - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < N; ++i) { - result[i] = scalar_op(lhs[i], scalar); - } - - return result; - } - - CUTLASS_HOST_DEVICE - Array operator()(T const &scalar, Array const &rhs) const { - - Array result; - maximum scalar_op; - - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < N; ++i) { - result[i] = scalar_op(scalar, rhs[i]); - } - - return result; - } -}; - -template -struct minimum, false> { - - CUTLASS_HOST_DEVICE - static T scalar_op(T const &lhs, T const &rhs) { - return (rhs < lhs ? 
rhs : lhs); - } +template +struct maximum, PropagateNaN> { CUTLASS_HOST_DEVICE Array operator()(Array const &lhs, Array const &rhs) const { Array result; - minimum scalar_op; + maximum scalar_op; CUTLASS_PRAGMA_UNROLL for (int i = 0; i < N; ++i) { @@ -921,7 +842,7 @@ struct minimum, false> { Array operator()(Array const &lhs, T const &scalar) const { Array result; - minimum scalar_op; + maximum scalar_op; CUTLASS_PRAGMA_UNROLL for (int i = 0; i < N; ++i) { @@ -935,7 +856,7 @@ struct minimum, false> { Array operator()(T const &scalar, Array const &rhs) const { Array result; - minimum scalar_op; + maximum scalar_op; CUTLASS_PRAGMA_UNROLL for (int i = 0; i < N; ++i) { @@ -946,8 +867,8 @@ struct minimum, false> { } }; -template -struct minimum, true> { +template +struct minimum, PropagateNaN> { CUTLASS_HOST_DEVICE static T scalar_op(T const &lhs, T const &rhs) { @@ -958,7 +879,7 @@ struct minimum, true> { Array operator()(Array const &lhs, Array const &rhs) const { Array result; - minimum scalar_op; + minimum scalar_op; CUTLASS_PRAGMA_UNROLL for (int i = 0; i < N; ++i) { @@ -972,7 +893,7 @@ struct minimum, true> { Array operator()(Array const &lhs, T const &scalar) const { Array result; - minimum scalar_op; + minimum scalar_op; CUTLASS_PRAGMA_UNROLL for (int i = 0; i < N; ++i) { @@ -986,7 +907,7 @@ struct minimum, true> { Array operator()(T const &scalar, Array const &rhs) const { Array result; - minimum scalar_op; + minimum scalar_op; CUTLASS_PRAGMA_UNROLL for (int i = 0; i < N; ++i) { @@ -2030,8 +1951,8 @@ struct multiply_add_relu0, Array, Array> } }; -template -struct minimum, false> { +template +struct minimum, PropagateNaN> { CUTLASS_HOST_DEVICE Array operator()(Array const & lhs, Array const &rhs) const { Array result; @@ -2043,25 +1964,27 @@ struct minimum, false> { CUTLASS_PRAGMA_UNROLL for (int i = 0; i < N / 2; ++i) { - result_ptr[i] = __hmin2(lhs_ptr[i], rhs_ptr[i]); + result_ptr[i] = PropagateNaN ? __hmin2_nan(lhs_ptr[i], rhs_ptr[i]) + : __hmin2(lhs_ptr[i], rhs_ptr[i]); } if constexpr (N % 2) { __half const *a_residual_ptr = reinterpret_cast<__half const *>(&lhs); __half const *b_residual_ptr = reinterpret_cast<__half const *>(&rhs); - __half d_residual = __hmin( - a_residual_ptr[N - 1], - b_residual_ptr[N - 1]); + __half d_residual = PropagateNaN ? __hmin_nan(a_residual_ptr[N - 1], b_residual_ptr[N - 1]) + : __hmin(a_residual_ptr[N - 1], b_residual_ptr[N - 1]); result[N - 1] = reinterpret_cast(d_residual); } #else + minimum mn; + CUTLASS_PRAGMA_UNROLL for (int i = 0; i < N; ++i) { - result[i] = (rhs[i] < lhs[i] ? rhs[i] : lhs[i]); + result[i] = mn(lhs[i],rhs[i]); } #endif @@ -2079,24 +2002,26 @@ struct minimum, false> { CUTLASS_PRAGMA_UNROLL for (int i = 0; i < N / 2; ++i) { - result_ptr[i] = __hmin2(lhs_pair, rhs_ptr[i]); + result_ptr[i] = PropagateNaN ? __hmin2_nan(lhs_pair, rhs_ptr[i]) + : __hmin2(lhs_pair, rhs_ptr[i]); } if constexpr (N % 2) { __half const *b_residual_ptr = reinterpret_cast<__half const *>(&rhs); - __half d_residual = __hmin( - reinterpret_cast<__half const &>(lhs), - b_residual_ptr[N - 1]); + __half d_residual = PropagateNaN ? __hmin_nan(reinterpret_cast<__half const &>(lhs), b_residual_ptr[N - 1]) + : __hmin(reinterpret_cast<__half const &>(lhs), b_residual_ptr[N - 1]); result[N - 1] = reinterpret_cast(d_residual); } #else + minimum mn; + CUTLASS_PRAGMA_UNROLL for (int i = 0; i < N; ++i) { - result[i] = (rhs[i] < lhs ? 
rhs[i] : lhs); + result[i] = mn(lhs, rhs[i]); } #endif @@ -2114,24 +2039,26 @@ CUTLASS_PRAGMA_UNROLL for (int i = 0; i < N / 2; ++i) { - result_ptr[i] = __hmin2(lhs_ptr[i], rhs_pair); + result_ptr[i] = PropagateNaN ? __hmin2_nan(lhs_ptr[i], rhs_pair) + : __hmin2(lhs_ptr[i], rhs_pair); } if constexpr (N % 2) { __half const *a_residual_ptr = reinterpret_cast<__half const *>(&lhs); - __half d_residual = __hmin( - a_residual_ptr[N - 1], - reinterpret_cast<__half const &>(rhs)); + __half d_residual = PropagateNaN ? __hmin_nan(a_residual_ptr[N - 1], reinterpret_cast<__half const &>(rhs)) + : __hmin(a_residual_ptr[N - 1], reinterpret_cast<__half const &>(rhs)); result[N - 1] = reinterpret_cast(d_residual); } #else + minimum mn; + CUTLASS_PRAGMA_UNROLL for (int i = 0; i < N; ++i) { - result[i] = (rhs < lhs[i] ? rhs : lhs[i]); + result[i] = mn(lhs[i], rhs); } #endif @@ -2139,8 +2066,8 @@ } }; -template -struct maximum, false> { +template +struct maximum, PropagateNaN> { CUTLASS_HOST_DEVICE Array operator()(Array const & lhs, Array const &rhs) const { Array result; @@ -2152,25 +2079,27 @@ CUTLASS_PRAGMA_UNROLL for (int i = 0; i < N / 2; ++i) { - result_ptr[i] = __hmax2(lhs_ptr[i], rhs_ptr[i]); + result_ptr[i] = PropagateNaN ? __hmax2_nan(lhs_ptr[i], rhs_ptr[i]) + : __hmax2(lhs_ptr[i], rhs_ptr[i]); } if constexpr (N % 2) { __half const *a_residual_ptr = reinterpret_cast<__half const *>(&lhs); __half const *b_residual_ptr = reinterpret_cast<__half const *>(&rhs); - __half d_residual = __hmax( - a_residual_ptr[N - 1], - b_residual_ptr[N - 1]); + __half d_residual = PropagateNaN ? __hmax_nan(a_residual_ptr[N - 1], b_residual_ptr[N - 1]) + : __hmax(a_residual_ptr[N - 1], b_residual_ptr[N - 1]); result[N - 1] = reinterpret_cast(d_residual); } #else + maximum mx; + CUTLASS_PRAGMA_UNROLL for (int i = 0; i < N; ++i) { - result[i] = (lhs[i] < rhs[i] ? rhs[i] : lhs[i]); + result[i] = mx(lhs[i], rhs[i]); } #endif @@ -2188,24 +2117,26 @@ CUTLASS_PRAGMA_UNROLL for (int i = 0; i < N / 2; ++i) { - result_ptr[i] = __hmax2(lhs_pair, rhs_ptr[i]); + result_ptr[i] = PropagateNaN ? __hmax2_nan(lhs_pair, rhs_ptr[i]) + : __hmax2(lhs_pair, rhs_ptr[i]); } if constexpr (N % 2) { __half const *b_residual_ptr = reinterpret_cast<__half const *>(&rhs); - __half d_residual = __hmax( - reinterpret_cast<__half const &>(lhs), - b_residual_ptr[N - 1]); + __half d_residual = PropagateNaN ? __hmax_nan(reinterpret_cast<__half const &>(lhs), b_residual_ptr[N - 1]) + : __hmax(reinterpret_cast<__half const &>(lhs), b_residual_ptr[N - 1]); result[N - 1] = reinterpret_cast(d_residual); } #else + maximum mx; + CUTLASS_PRAGMA_UNROLL for (int i = 0; i < N; ++i) { - result[i] = (lhs < rhs[i] ? rhs[i] : lhs); + result[i] = mx(lhs, rhs[i]); } #endif @@ -2223,24 +2154,26 @@ CUTLASS_PRAGMA_UNROLL for (int i = 0; i < N / 2; ++i) { - result_ptr[i] = __hmax2(lhs_ptr[i], rhs_pair); + result_ptr[i] = PropagateNaN ? __hmax2_nan(lhs_ptr[i], rhs_pair) + : __hmax2(lhs_ptr[i], rhs_pair); } if constexpr (N % 2) { __half const *a_residual_ptr = reinterpret_cast<__half const *>(&lhs); - __half d_residual = __hmax( - a_residual_ptr[N - 1], - reinterpret_cast<__half const &>(rhs)); + __half d_residual = PropagateNaN ? 
__hmax_nan(a_residual_ptr[N - 1], reinterpret_cast<__half const &>(rhs)) + : __hmax(a_residual_ptr[N - 1], reinterpret_cast<__half const &>(rhs)); result[N - 1] = reinterpret_cast(d_residual); } #else + maximum mx; + CUTLASS_PRAGMA_UNROLL for (int i = 0; i < N; ++i) { - result[i] = (lhs[i] < rhs ? rhs : lhs[i]); + result[i] = mx(lhs[i], rhs); } #endif diff --git a/include/cutlass/bfloat16.h b/include/cutlass/bfloat16.h index 50506c73be..5af6d3ab80 100644 --- a/include/cutlass/bfloat16.h +++ b/include/cutlass/bfloat16.h @@ -190,6 +190,12 @@ struct alignas(2) bfloat16_t { return (float(*this) != 0.0f); } + /// Bitcasts to CUDA's bf16 type + CUTLASS_DEVICE + __nv_bfloat16 to_nv_bfloat16() const { + return reinterpret_cast<__nv_bfloat16 const &>(storage); + } + /// Obtains raw bits CUTLASS_HOST_DEVICE uint16_t raw() const { @@ -321,9 +327,9 @@ bfloat16_t copysign(bfloat16_t const& a, bfloat16_t const& b) { // /////////////////////////////////////////////////////////////////////////////////////////////////// +#if !defined(__CUDACC_RTC__) namespace std { -#if !defined(__CUDACC_RTC__) /// Numeric limits template <> struct numeric_limits { @@ -378,9 +384,78 @@ struct numeric_limits { CUTLASS_HOST_DEVICE static cutlass::bfloat16_t denorm_min() { return cutlass::bfloat16_t::bitcast(0x1); } }; -#endif } // namespace std +#endif + +namespace cutlass { +namespace platform { + +/// Forward Declaration +template +struct numeric_limits; + +/// Numeric limits +template <> +struct numeric_limits { + static bool const is_specialized = true; + static bool const is_signed = true; + static bool const is_integer = false; + static bool const is_exact = false; + static bool const has_infinity = true; + static bool const has_quiet_NaN = true; + static bool const has_signaling_NaN = false; +#if !defined(__CUDACC_RTC__) + static std::float_denorm_style const has_denorm = std::denorm_present; +#endif + static bool const has_denorm_loss = true; +#if !defined(__CUDACC_RTC__) + static std::float_round_style const round_style = std::round_to_nearest; +#endif + static bool const is_iec559 = false; + static bool const is_bounded = true; + static bool const is_modulo = false; + static int const digits = 7; + + /// Least positive value + CUTLASS_HOST_DEVICE + static cutlass::bfloat16_t min() { return cutlass::bfloat16_t::bitcast(0x01); } + + /// Minimum finite value + CUTLASS_HOST_DEVICE + static cutlass::bfloat16_t lowest() { return cutlass::bfloat16_t::bitcast(0xff7f); } + + /// Maximum finite value + CUTLASS_HOST_DEVICE + static cutlass::bfloat16_t max() { return cutlass::bfloat16_t::bitcast(0x7f7f); } + + /// Returns smallest finite value + CUTLASS_HOST_DEVICE + static cutlass::bfloat16_t epsilon() { return cutlass::bfloat16_t::bitcast(0x1000); } + + /// Returns smallest finite value + CUTLASS_HOST_DEVICE + static cutlass::bfloat16_t round_error() { return cutlass::bfloat16_t(0.5f); } + + /// Returns smallest finite value + CUTLASS_HOST_DEVICE + static cutlass::bfloat16_t infinity() { return cutlass::bfloat16_t::bitcast(0x7f80); } + + /// Returns smallest finite value + CUTLASS_HOST_DEVICE + static cutlass::bfloat16_t quiet_NaN() { return cutlass::bfloat16_t::bitcast(0x7fff); } + + /// Returns smallest finite value + CUTLASS_HOST_DEVICE + static cutlass::bfloat16_t signaling_NaN() { return cutlass::bfloat16_t::bitcast(0x7fff); } + + /// Returns smallest finite value + CUTLASS_HOST_DEVICE + static cutlass::bfloat16_t denorm_min() { return cutlass::bfloat16_t::bitcast(0x1); } +}; + +} // namespace platform +} // namespace 
cutlass /////////////////////////////////////////////////////////////////////////////////////////////////// // @@ -394,114 +469,190 @@ namespace cutlass { CUTLASS_HOST_DEVICE bool operator==(bfloat16_t const& lhs, bfloat16_t const& rhs) { +#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) + return __heq(lhs.to_nv_bfloat16(), rhs.to_nv_bfloat16()); +#else return float(lhs) == float(rhs); +#endif } CUTLASS_HOST_DEVICE bool operator!=(bfloat16_t const& lhs, bfloat16_t const& rhs) { +#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) + return __hne(lhs.to_nv_bfloat16(), rhs.to_nv_bfloat16()); +#else return float(lhs) != float(rhs); +#endif } CUTLASS_HOST_DEVICE bool operator<(bfloat16_t const& lhs, bfloat16_t const& rhs) { +#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) + return __hlt(lhs.to_nv_bfloat16(), rhs.to_nv_bfloat16()); +#else return float(lhs) < float(rhs); +#endif } CUTLASS_HOST_DEVICE bool operator<=(bfloat16_t const& lhs, bfloat16_t const& rhs) { +#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) + return __hle(lhs.to_nv_bfloat16(), rhs.to_nv_bfloat16()); +#else return float(lhs) <= float(rhs); +#endif } CUTLASS_HOST_DEVICE bool operator>(bfloat16_t const& lhs, bfloat16_t const& rhs) { +#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) + return __hgt(lhs.to_nv_bfloat16(), rhs.to_nv_bfloat16()); +#else return float(lhs) > float(rhs); +#endif } CUTLASS_HOST_DEVICE bool operator>=(bfloat16_t const& lhs, bfloat16_t const& rhs) { +#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) + return __hge(lhs.to_nv_bfloat16(), rhs.to_nv_bfloat16()); +#else return float(lhs) >= float(rhs); +#endif } CUTLASS_HOST_DEVICE bfloat16_t operator+(bfloat16_t const& lhs, bfloat16_t const& rhs) { +#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) + return bfloat16_t(__hadd(lhs.to_nv_bfloat16(), rhs.to_nv_bfloat16())); +#else return bfloat16_t(float(lhs) + float(rhs)); +#endif } CUTLASS_HOST_DEVICE bfloat16_t operator-(bfloat16_t const& lhs) { +#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) + return bfloat16_t(__hneg(lhs.to_nv_bfloat16())); +#else return bfloat16_t(-float(lhs)); +#endif } CUTLASS_HOST_DEVICE bfloat16_t operator-(bfloat16_t const& lhs, bfloat16_t const& rhs) { +#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) + return bfloat16_t(__hsub(lhs.to_nv_bfloat16(), rhs.to_nv_bfloat16())); +#else return bfloat16_t(float(lhs) - float(rhs)); +#endif } CUTLASS_HOST_DEVICE bfloat16_t operator*(bfloat16_t const& lhs, bfloat16_t const& rhs) { +#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) + return bfloat16_t(__hmul(lhs.to_nv_bfloat16(), rhs.to_nv_bfloat16())); +#else return bfloat16_t(float(lhs) * float(rhs)); +#endif } CUTLASS_HOST_DEVICE bfloat16_t operator/(bfloat16_t const& lhs, bfloat16_t const& rhs) { +#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) + return bfloat16_t(__hdiv(lhs.to_nv_bfloat16(), rhs.to_nv_bfloat16())); +#else return bfloat16_t(float(lhs) / float(rhs)); +#endif } CUTLASS_HOST_DEVICE bfloat16_t& operator+=(bfloat16_t & lhs, bfloat16_t const& rhs) { +#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) + lhs = bfloat16_t(__hadd(lhs.to_nv_bfloat16(), rhs.to_nv_bfloat16())); +#else lhs = bfloat16_t(float(lhs) + float(rhs)); +#endif return lhs; } CUTLASS_HOST_DEVICE bfloat16_t& operator-=(bfloat16_t & lhs, bfloat16_t const& rhs) { +#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) + lhs = bfloat16_t(__hsub(lhs.to_nv_bfloat16(), rhs.to_nv_bfloat16())); +#else lhs = bfloat16_t(float(lhs) - float(rhs)); +#endif return lhs; } CUTLASS_HOST_DEVICE 
bfloat16_t& operator*=(bfloat16_t & lhs, bfloat16_t const& rhs) { +#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) + lhs = bfloat16_t(__hmul(lhs.to_nv_bfloat16(), rhs.to_nv_bfloat16())); +#else lhs = bfloat16_t(float(lhs) * float(rhs)); +#endif return lhs; } CUTLASS_HOST_DEVICE bfloat16_t& operator/=(bfloat16_t & lhs, bfloat16_t const& rhs) { +#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) + lhs = bfloat16_t(__hdiv(lhs.to_nv_bfloat16(), rhs.to_nv_bfloat16())); +#else lhs = bfloat16_t(float(lhs) / float(rhs)); +#endif return lhs; } CUTLASS_HOST_DEVICE bfloat16_t& operator++(bfloat16_t & lhs) { +#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) + lhs = bfloat16_t(__hadd(lhs.to_nv_bfloat16(), bfloat16_t(1.0f).to_nv_bfloat16())); +#else float tmp(lhs); ++tmp; lhs = bfloat16_t(tmp); +#endif return lhs; } CUTLASS_HOST_DEVICE bfloat16_t& operator--(bfloat16_t & lhs) { +#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) + lhs = bfloat16_t(__hsub(lhs.to_nv_bfloat16(), bfloat16_t(1.0f).to_nv_bfloat16())); +#else float tmp(lhs); --tmp; lhs = bfloat16_t(tmp); +#endif return lhs; } CUTLASS_HOST_DEVICE bfloat16_t operator++(bfloat16_t & lhs, int) { bfloat16_t ret(lhs); +#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) + lhs = bfloat16_t(__hadd(lhs.to_nv_bfloat16(), bfloat16_t(1.0f).to_nv_bfloat16())); +#else float tmp(lhs); tmp++; lhs = bfloat16_t(tmp); +#endif return ret; } CUTLASS_HOST_DEVICE bfloat16_t operator--(bfloat16_t & lhs, int) { bfloat16_t ret(lhs); +#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) + lhs = bfloat16_t(__hsub(lhs.to_nv_bfloat16(), bfloat16_t(1.0f).to_nv_bfloat16())); +#else float tmp(lhs); tmp--; lhs = bfloat16_t(tmp); +#endif return ret; } diff --git a/include/cutlass/cluster_launch.hpp b/include/cutlass/cluster_launch.hpp index 3d140eaa84..a0fa22b6bb 100644 --- a/include/cutlass/cluster_launch.hpp +++ b/include/cutlass/cluster_launch.hpp @@ -172,6 +172,7 @@ struct ClusterLauncher { "And ClusterDims = " "(" << cluster_dims.x << ", " << cluster_dims.y << ", " << cluster_dims.z << ")\n"); + cutlass::arch::synclog_setup(); cudaError_t status = cudaLaunchKernelExC(&launch_config, kernel, kernel_params); Return_Status(status); #else diff --git a/include/cutlass/conv/collective/sm90_implicit_gemm_gmma_ss_warpspecialized.hpp b/include/cutlass/conv/collective/sm90_implicit_gemm_gmma_ss_warpspecialized.hpp index 84fd37b47e..78862b0a09 100644 --- a/include/cutlass/conv/collective/sm90_implicit_gemm_gmma_ss_warpspecialized.hpp +++ b/include/cutlass/conv/collective/sm90_implicit_gemm_gmma_ss_warpspecialized.hpp @@ -41,8 +41,8 @@ #include "cute/algorithm/functional.hpp" #include "cute/algorithm/gemm.hpp" +#include "cutlass/conv/detail.hpp" #include "cutlass/conv/convolution.h" -#include "cutlass/conv/convnd_problem_shape.hpp" #include "cutlass/conv/dispatch_policy.hpp" #include "cutlass/pipeline/pipeline.hpp" #include "cutlass/util/packed_stride.hpp" @@ -103,6 +103,8 @@ struct CollectiveConv< using PipelineParams = typename MainloopPipeline::Params; using PipelineState = typename cutlass::PipelineState; + + using ProblemShape = ConvProblemShape; // TODO: move pipeline mode tiling into the collective setup phase instead static_assert(rank(SmemLayoutA{}) == 3, "SmemLayout must be rank 3 (M/N, K, PIPE)"); @@ -143,7 +145,7 @@ struct CollectiveConv< struct SharedStorage { - struct TensorStorage : cute::aligned_struct<128> { + struct TensorStorage : cute::aligned_struct<128, _0> { cute::array_aligned> smem_A; cute::array_aligned> smem_B; } tensors; @@ -162,8 +164,6 @@ 
struct CollectiveConv< // Host side kernel arguments struct Arguments { - using ProblemShape = ConvProblemShape; - - ProblemShape problem_shape{}; - ElementA const* ptr_A{nullptr}; ElementB const* ptr_B{nullptr}; }; @@ -175,7 +175,7 @@ struct CollectiveConv< // Get tma_load_a instance. template static constexpr auto - get_tma_load_a_instance(TensorA const& tensor_a, typename Arguments::ProblemShape const& problem_shape) { + get_tma_load_a_instance(TensorA const& tensor_a, ProblemShape const& problem_shape) { if constexpr (is_im2col_A) { // compute the upper and lower corners based on the conv padding auto lower_corner_whd = detail::compute_lower_corner_whd(problem_shape); @@ -218,7 +218,7 @@ struct CollectiveConv< // Get tma_load_b instance. template static constexpr auto - get_tma_load_b_instance(TensorB const& tensor_b, typename Arguments::ProblemShape const& problem_shape) { + get_tma_load_b_instance(TensorB const& tensor_b, ProblemShape const& problem_shape) { // TMA im2col mode for tensor B in wgrad kernel. if constexpr (is_im2col_B) { // compute the upper and lower corners based on the conv padding @@ -250,24 +250,25 @@ struct CollectiveConv< } } +public: + + // Performs im2col transformations on the input of type ConvProblemShape static constexpr auto - get_problem_shape_MNKL(typename Arguments::ProblemShape const& problem_shape) { + get_problem_shape_MNKL(ProblemShape const& problem_shape) { + if constexpr (is_im2col_A || is_im2col_B) { // transformation + im2col linearization - return problem_shape.get_linearized_problem_shape_MNKL(); + return cutlass::conv::detail::get_linearized_problem_shape_MNKL(problem_shape); } else { // transformation - return problem_shape.get_transformed_problem_shape_MNKL(); + return cutlass::conv::detail::get_transformed_problem_shape_MNKL(problem_shape); } } -public: - // Device side kernel params struct Params { - using _Submode = decltype(take<0,NumTensorDimensions-1>(typename Arguments::ProblemShape::TensorExtent{})); - using ProblemShape = decltype(get_problem_shape_MNKL(typename Arguments::ProblemShape{})); + using _Submode = decltype(take<0,NumTensorDimensions-1>(typename ProblemShape::TensorExtent{})); // Assumption: StrideA is congruent with Problem_MK // Select TMA load type according to convolution operator. @@ -294,7 +295,6 @@ struct CollectiveConv< // Members TMA_A tma_load_a; TMA_B tma_load_b; - ProblemShape problem_shape; uint32_t tma_transaction_bytes = TmaTransactionBytes; }; @@ -304,19 +304,19 @@ struct CollectiveConv< // Lowers the host side user facing arguments to the kernel facing launch params static constexpr Params - to_underlying_arguments(Arguments const& args, void* workspace) { + to_underlying_arguments(ProblemShape const& problem_shape, Arguments const& args, void* workspace) { (void) workspace; // from the flat problem shape arrays of ConvProblemShape, create a rank-3 MNK problem shape tuple // tma desc creation depends on the original untransformed domain. // A extents. - auto shape_A_orig = args.problem_shape.get_shape_A(); // B extents.
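// A concrete example of these extents (hypothetical sizes, 3-D fprop): for an NDHWC
// activation with shape_act = [N,D,H,W,C] = [2,4,8,8,64], get_shape_A() returns the
// reversed, channel-split form ((W,H,D,N),(C)) = ((8,8,4,2),(64)), and get_shape_B()
// does the analogous split for the filter tensor. These untransformed extents drive
// the TMA descriptor creation here, while the scheduler-facing MNKL view comes from
// get_problem_shape_MNKL(problem_shape) above.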
- auto shape_B_orig = args.problem_shape.get_shape_B(); + auto shape_B_orig = problem_shape.get_shape_B(); // Fill inferred cute strides from flat stride arrays - auto dA = make_cute_packed_stride(StrideA{}, args.problem_shape.stride_A, ConvOp); - auto dB = make_cute_packed_stride(StrideB{}, args.problem_shape.stride_B, ConvOp); + auto dA = make_cute_packed_stride(StrideA{}, problem_shape.stride_A, ConvOp); + auto dB = make_cute_packed_stride(StrideB{}, problem_shape.stride_B, ConvOp); auto ptr_A = reinterpret_cast(args.ptr_A); auto ptr_B = reinterpret_cast(args.ptr_B); @@ -324,20 +324,17 @@ struct CollectiveConv< Tensor tensor_a = make_tensor(make_gmem_ptr(ptr_A), make_layout(shape_A_orig, dA)); Tensor tensor_b = make_tensor(make_gmem_ptr(ptr_B), make_layout(shape_B_orig, dB)); - auto tma_load_a = get_tma_load_a_instance(tensor_a, args.problem_shape); - auto tma_load_b = get_tma_load_b_instance(tensor_b, args.problem_shape); - - auto problem_shape_mnkl = get_problem_shape_MNKL(args.problem_shape); + auto tma_load_a = get_tma_load_a_instance(tensor_a, problem_shape); + auto tma_load_b = get_tma_load_b_instance(tensor_b, problem_shape); return { tma_load_a, tma_load_b, - problem_shape_mnkl, TmaTransactionBytes }; } - - template + + template static bool can_implement( ProblemShape const& problem_shape, @@ -345,14 +342,14 @@ struct CollectiveConv< // Activation and Filter channel mode extents must match bool implementable = true; // channel mode is major - implementable &= args.problem_shape.stride_A[NumTensorDimensions-1] == 1; - implementable &= args.problem_shape.stride_B[NumTensorDimensions-1] == 1; + implementable &= problem_shape.stride_A[NumTensorDimensions-1] == 1; + implementable &= problem_shape.stride_B[NumTensorDimensions-1] == 1; constexpr int tma_alignment_bits = 128; // A extents. - auto shape_A_orig = args.problem_shape.get_shape_A(); + auto shape_A_orig = problem_shape.get_shape_A(); // B extents.
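// For intuition (hypothetical sizes): a packed NDHWC activation with C = 64 has
// stride_A = (D*H*W*C, H*W*C, W*C, C, 1), so the innermost (channel) stride is 1 and
// the channel-major check above passes; a channels-first layout would fail it.
// The 128-bit TMA alignment verified just below requires the contiguous extent to be
// a multiple of 128 / sizeof_bits<Element>::value, e.g. 8 elements for 16-bit types.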
- auto shape_B_orig = args.problem_shape.get_shape_B(); + auto shape_B_orig = problem_shape.get_shape_B(); constexpr int min_tma_aligned_elements_A = tma_alignment_bits / cutlass::sizeof_bits::value; implementable = implementable && cutlass::detail::check_alignment(shape_A_orig, StrideA{}); constexpr int min_tma_aligned_elements_B = tma_alignment_bits / cutlass::sizeof_bits::value; @@ -375,61 +372,6 @@ struct CollectiveConv< return false; } - if (is_im2col_A || is_im2col_B) { - // Check valid corner values for TMA_LOAD_IM2COL, signed int ranging from [-corner_limit, corner_limit - 1] - constexpr int32_t corner_limit = 1 << (16 / NumSpatialDimensions - 1); - auto lower_corner_whd = detail::compute_lower_corner_whd(problem_shape); - for (int i = 0; i < problem_shape.RankS; ++i) { - implementable = implementable && lower_corner_whd[i] >= -corner_limit && lower_corner_whd[i] <= (corner_limit - 1); - } - auto upper_corner_whd = detail::compute_upper_corner_whd(problem_shape); - for (int i = 0; i < problem_shape.RankS; ++i) { - implementable = implementable && upper_corner_whd[i] >= -corner_limit && upper_corner_whd[i] <= (corner_limit - 1); - } - - if (!implementable) { - CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Padding values don't meet requirements for TMA LOAD IM2COL.\n"); - return false; - } - } - - // Wgrad kernels don't support non-packed output strides, non-packed tensor A stride (linearized) - if constexpr (ConvOp == conv::Operator::kWgrad) { - - const auto & input_shape = problem_shape.shape_A; - const auto & input_stride = problem_shape.stride_A; - - implementable &= input_stride[ProblemShape::RankT - 1] == 1; - int input_shape_size = 1; - for (int i = ProblemShape::RankT - 2; i >= 0; --i) { - input_shape_size *= input_shape[i + 1]; - implementable &= input_stride[i] == input_shape_size; - } - - const auto & output_shape = problem_shape.shape_C; - const auto & output_stride = problem_shape.stride_C; - - implementable &= output_stride[ProblemShape::RankT - 1] == 1; - int output_shape_size = 1; - for (int i = ProblemShape::RankT - 2; i >= 0; --i) { - output_shape_size *= output_shape[i + 1]; - implementable &= output_stride[i] == output_shape_size; - } - - if (!implementable) { - CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Wgrad kernels don't support non-packed output strides.\n"); - return false; - } - } - - // Conv kernels only support cross correlation mode currently. - implementable &= problem_shape.mode == cutlass::conv::Mode::kCrossCorrelation; - - if (!implementable) { - CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Conv kernels only support cross correlation mode currently.\n"); - return false; - } - if (problem_shape.groups > 1) { CUTLASS_TRACE_HOST(" CAN IMPLEMENT: This kernel does not support conv groups > 1.\n"); return false; @@ -445,24 +387,53 @@ struct CollectiveConv< cute::prefetch_tma_descriptor(mainloop_params.tma_load_b.get_tma_descriptor()); } + /// Set up the data needed by this collective for load and mma. + /// Returns a tuple of tensors. The collective and the kernel layer have the contract + /// Returned tuple must contain at least two elements, with the first two elements being: + /// gA_mk - The tma tensor, A after a local tile so it has shape (BLK_M,BLK_K,m,k) + /// gB_nk - The tma tensor, B after a local tile so it has shape (BLK_N,BLK_K,n,k) + /// The rest of the tensors can be specified as needed by this collective. 
+ /// The dimensions of gA_mk and gA_nk do not contain L to maintain consistency with + /// StrideA and StrideB set up for TMA + template + CUTLASS_DEVICE auto + load_init(ProblemShapeMNKL const& problem_shape_MNKL, Params const& mainloop_params){ + //load_init(ProblemShapeMNKL const& problem_shape_MNKL, Params const& mainloop_params) const { + using X = Underscore; + // Separate out problem shape for convenience + auto [M, N, K, L] = problem_shape_MNKL; + + // TMA requires special handling of strides to deal with coord codomain mapping + // Represent the full tensors -- get these from TMA + Tensor mA_mk = mainloop_params.tma_load_a.get_tma_tensor(make_shape(M,K)); // (m,k) + Tensor mB_nk = mainloop_params.tma_load_b.get_tma_tensor(make_shape(N,K)); // (n,k) + + // Make tiled views, defer the slice + Tensor gA_mk = local_tile(mA_mk, TileShape{}, make_coord(_,_,_), Step<_1, X,_1>{}); // (BLK_M,BLK_K,m,k) + Tensor gB_nk = local_tile(mB_nk, TileShape{}, make_coord(_,_,_), Step< X,_1,_1>{}); // (BLK_N,BLK_K,n,k) + + return cute::make_tuple(gA_mk, gB_nk); + } + /// Perform a collective-scoped matrix multiply-accumulate /// Producer Perspective template < - class TensorA, class TMA_LOAD_A, - class TensorB, class TMA_LOAD_B, - class KTileIterator + class TensorA, class TensorB, + class KTileIterator, class BlockCoord > CUTLASS_DEVICE void - load(MainloopPipeline pipeline, - PipelineState smem_pipe_producer_state, - TensorA const& gA, TMA_LOAD_A& tma_load_a, - TensorB const& gB, TMA_LOAD_B& tma_load_b, - KTileIterator k_tile_iter, int k_tile_count, - int thread_idx, - uint32_t block_rank_in_cluster, - TensorStorage& shared_tensors) { - int lane_predicate = cute::elect_one_sync(); + load( + Params const& mainloop_params, + MainloopPipeline pipeline, + PipelineState smem_pipe_producer_state, + cute::tuple const& load_inputs, + BlockCoord const& blk_coord, + KTileIterator k_tile_iter, int k_tile_count, + int thread_idx, + uint32_t block_rank_in_cluster, + TensorStorage& shared_tensors) { + int lane_predicate = cute::elect_one_sync(); if (lane_predicate) { Tensor sA = make_tensor(make_smem_ptr(shared_tensors.smem_A.data()), SmemLayoutA{}); // (BLK_M,BLK_K,PIPE) Tensor sB = make_tensor(make_smem_ptr(shared_tensors.smem_B.data()), SmemLayoutB{}); // (BLK_N,BLK_K,PIPE) @@ -470,11 +441,19 @@ struct CollectiveConv< // // Prepare the TMA loads for A and B // - constexpr uint32_t cluster_shape_x = get<0>(ClusterShape()); + uint2 cluster_local_block_id = {block_rank_in_cluster % cluster_shape_x, block_rank_in_cluster / cluster_shape_x}; - auto block_tma_a = tma_load_a.get_slice(cluster_local_block_id.y); - auto block_tma_b = tma_load_b.get_slice(cluster_local_block_id.x); + auto block_tma_a = mainloop_params.tma_load_a.get_slice(cluster_local_block_id.y); + auto block_tma_b = mainloop_params.tma_load_b.get_slice(cluster_local_block_id.x); + + auto [gA_mk, gB_nk] = load_inputs; + + // Partition the inputs based on the current block coordinates. 
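// Shape bookkeeping for this step (hypothetical sizes): with TileShape = (128,128,64)
// and a linearized problem (M,N,K) = (1024,512,256), load_init produced
//   gA_mk : (BLK_M,BLK_K,m,k) = (128,64,8,4)   gB_nk : (BLK_N,BLK_K,n,k) = (128,64,4,4)
// Slicing with the block coordinates below keeps only this CTA's row/column of tiles,
//   gA = gA_mk(_,_,m_coord,_) : (128,64,4)     gB = gB_nk(_,_,n_coord,_) : (128,64,4)
// so the k-tile loop that follows iterates over the last mode of both tensors.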
+ auto [m_coord, n_coord, k_coord, l_coord] = blk_coord; + + Tensor gA = gA_mk(_,_,m_coord,_); // (BLK_M,BLK_K,k) + Tensor gB = gB_nk(_,_,n_coord,_); // (BLK_N,BLK_K,k) // Applies the mapping from block_tma_a Tensor tAgA = block_tma_a.partition_S(gA); // (TMA,TMA_M,TMA_K,k) @@ -518,8 +497,9 @@ struct CollectiveConv< BarrierType* tma_barrier = pipeline.producer_get_barrier(smem_pipe_producer_state); int write_stage = smem_pipe_producer_state.index(); - copy(tma_load_a.with(*tma_barrier, mcast_mask_a), tAgA(_,_,_,*k_tile_iter), tAsA(_,_,_,write_stage)); - copy(tma_load_b.with(*tma_barrier, mcast_mask_b), tBgB(_,_,_,*k_tile_iter), tBsB(_,_,_,write_stage)); + + copy(mainloop_params.tma_load_a.with(*tma_barrier, mcast_mask_a), tAgA(_,_,_,*k_tile_iter), tAsA(_,_,_,write_stage)); + copy(mainloop_params.tma_load_b.with(*tma_barrier, mcast_mask_b), tBgB(_,_,_,*k_tile_iter), tBsB(_,_,_,write_stage)); ++k_tile_iter; // Advance smem_pipe_producer_state diff --git a/include/cutlass/conv/convnd_problem_shape.hpp b/include/cutlass/conv/convnd_problem_shape.hpp index 0172120538..ffcc547fbd 100644 --- a/include/cutlass/conv/convnd_problem_shape.hpp +++ b/include/cutlass/conv/convnd_problem_shape.hpp @@ -43,6 +43,7 @@ #include #endif + //////////////////////////////////////////////////////////////////////////////////////////////////// namespace cutlass::conv { @@ -54,15 +55,17 @@ namespace cutlass::conv { // Supports asymmetric padding, traversal strides, dilations, and all conv algorithm types. template < conv::Operator ConvOp_, - int NumSpatialDimensions + int NumSpatialDimensions_ > struct ConvProblemShape { // // Alias types for members // - static constexpr int RankS = NumSpatialDimensions; - static constexpr int RankT = NumSpatialDimensions + 2; + + static constexpr int RankS = NumSpatialDimensions_; + static constexpr int RankT = NumSpatialDimensions_ + 2; static constexpr conv::Operator ConvOp = ConvOp_; + static constexpr int NumSpatialDimensions = NumSpatialDimensions_; using SpatialExtent = cute::array; using TensorExtent = cute::array; using TensorStride = cute::array; @@ -352,71 +355,6 @@ struct ConvProblemShape { } } - // Get problem shape MNKL according to following table: - // | | Fprop | Dgrad | Wgrad | - // | ---- | --------- | -------- | -------- | - // | Shape_M | (Q,P,Z,N) | (W/V,H/U,D/O,N) | (K) | - // | Shape_N | (K) | (C) | (C,S,R,T) | - // | Shape_K | (C,S,R,T) | (K,S,R,T) | (Q,P,Z,N) | - // | Shape_L | _1 | (V,U,O) | _1 | - CUTLASS_HOST_DEVICE - constexpr auto - get_transformed_problem_shape_MNKL() const { - using cute::insert; - using cute::make_shape; - using cute::reverse; - using cute::take; - - if constexpr (ConvOp == conv::Operator::kWgrad) { - auto M_xformed = shape_C[0]; - auto N_xformed = reverse(take<1, RankT>(shape_C)); - auto K_xformed = reverse(take<0, RankT - 1>(shape_A)); - auto L_xformed = cute::Int<1>{}; - - return make_shape(M_xformed, N_xformed, K_xformed, L_xformed); - } - else if constexpr (ConvOp == conv::Operator::kFprop){ - auto M_xformed = reverse(take<0, RankT - 1>(shape_C)); - auto N_xformed = shape_C[RankT - 1]; - auto K_xformed = reverse(take<1, RankT>(shape_B)); - auto L_xformed = cute::Int<1>{}; - - return make_shape(M_xformed, N_xformed, K_xformed, L_xformed); - } - else if constexpr (ConvOp == conv::Operator::kDgrad) { - auto L_xformed = reverse(traversal_stride); // (V,U,O) - auto M_xformed = ceil_div(reverse(take<0,RankT - 1>(shape_C)), L_xformed); - auto N_xformed = shape_C[RankT - 1]; - // shape_B: [K,T,R,S,C], K_xformed: [K,S,R,T] - auto 
K_xformed = insert<0>( - (reverse(take<1,RankT - 1>(shape_B))), - shape_B[0]); - - return make_shape(M_xformed, N_xformed, K_xformed, L_xformed); - } - } - - // Assuming im2col linearization - // Get problem shape MNKL according to following table: - // | | Fprop | Dgrad | Wgrad | - // | ---- | --------- | -------- | -------- | - // | Shape_M | (Q*P*Z*N) | ([W/V]*[H/U]*[D/O]*N) | (K) | - // | Shape_N | (K) | (C) | (C,S,R,T) | - // | Shape_K | (C,S,R,T) | (K,S,R,T) | (Q*P*Z*N) | - // | Shape_L | _1 | (V*U*O) | _1 | - CUTLASS_HOST_DEVICE - constexpr auto - get_linearized_problem_shape_MNKL() const { - auto [M, N, K, L] = get_transformed_problem_shape_MNKL(); - - if constexpr (ConvOp == conv::Operator::kFprop || ConvOp == conv::Operator::kDgrad) { - return cute::make_shape(cute::product(M), N, K, cute::product(L)); - } - else if constexpr (ConvOp == conv::Operator::kWgrad) { - return cute::make_shape(M, N, cute::product(K), L); - } - } - // Get A extents. // fprop: A extents array contains [N,D,H,W,C]. Turn that into ((W,H,D,N), (C)) // dgrad: A extents array contains [N,Z,P,Q,K]. Turn that into ((Q,P,Z,N), (K)) @@ -578,9 +516,7 @@ struct ConvProblemShape { // calculate n,z,p,q,k. // a helper lambda to compute a single spatial extent of the nzpqk tensor auto nzpqk_extent = [](int act_ext, int filter_ext, int pad_total, int dilation, int tstride) { - auto tmp = act_ext + pad_total - ((filter_ext -1) * dilation + 1); - CUTLASS_ASSERT(tmp % tstride == 0); - return 1 + tmp / tstride; + return 1 + (act_ext + pad_total - ((filter_ext -1) * dilation + 1)) / tstride; }; shape_xformed_act[0] = shape_act[0]; // Activation N extent diff --git a/include/cutlass/conv/detail.hpp b/include/cutlass/conv/detail.hpp new file mode 100644 index 0000000000..3e4173569c --- /dev/null +++ b/include/cutlass/conv/detail.hpp @@ -0,0 +1,137 @@ + +/*************************************************************************************************** + * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +#include "cutlass/conv/convnd_problem_shape.hpp" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass::conv::detail { + +///////////////////////////////////////////////////////////////////////////////////////////////// + + // Helper function to get the problem shape +template +auto get_problem_shape_MNKL_helper(ProblemShape const& problem_shape, cute::true_type) { + return T::get_problem_shape_MNKL(problem_shape); +} + +template +ProblemShape get_problem_shape_MNKL_helper(ProblemShape const& problem_shape, cute::false_type) { + return problem_shape; +} + +// Get problem shape MNKL according to following table: +// | | Fprop | Dgrad | Wgrad | +// | ---- | --------- | -------- | -------- | +// | Shape_M | (Q,P,Z,N) | (W/V,H/U,D/O,N) | (K) | +// | Shape_N | (K) | (C) | (C,S,R,T) | +// | Shape_K | (C,S,R,T) | (K,S,R,T) | (Q,P,Z,N) | +// | Shape_L | _1 | (V,U,O) | _1 | + +template +CUTLASS_HOST_DEVICE +constexpr auto +get_transformed_problem_shape_MNKL(ProblemShape const& problem_shape) { + return problem_shape; +} + + +template +CUTLASS_HOST_DEVICE +constexpr auto +get_transformed_problem_shape_MNKL(ConvProblemShape const& problem_shape) { + using cute::insert; + using cute::make_shape; + using cute::reverse; + using cute::take; + + constexpr int RankT = SpatialDim + 2; + + if constexpr (ConvOp == conv::Operator::kWgrad) { + auto M_xformed = problem_shape.shape_C[0]; + auto N_xformed = reverse(take<1, RankT>(problem_shape.shape_C)); + auto K_xformed = reverse(take<0, RankT - 1>(problem_shape.shape_A)); + auto L_xformed = cute::Int<1>{}; + + return make_shape(M_xformed, N_xformed, K_xformed, L_xformed); + } + else if constexpr (ConvOp == conv::Operator::kFprop){ + auto M_xformed = reverse(take<0, RankT - 1>(problem_shape.shape_C)); + auto N_xformed = problem_shape.shape_C[RankT - 1]; + auto K_xformed = reverse(take<1, RankT>(problem_shape.shape_B)); + auto L_xformed = cute::Int<1>{}; + + return make_shape(M_xformed, N_xformed, K_xformed, L_xformed); + } + else if constexpr (ConvOp == conv::Operator::kDgrad) { + auto L_xformed = reverse(problem_shape.traversal_stride); // (V,U,O) + auto M_xformed = ceil_div(reverse(take<0,RankT - 1>(problem_shape.shape_C)), L_xformed); + auto N_xformed = problem_shape.shape_C[RankT - 1]; + // shape_B: [K,T,R,S,C], K_xformed: [K,S,R,T] + auto K_xformed = insert<0>( + (reverse(take<1,RankT - 1>(problem_shape.shape_B))), + problem_shape.shape_B[0]); + + return make_shape(M_xformed, N_xformed, K_xformed, L_xformed); + } +} + +// Assuming im2col linearization +// Get problem shape MNKL according to following table: +// | | Fprop | Dgrad | Wgrad | +// | ---- | --------- | -------- | -------- | +// | Shape_M | (Q*P*Z*N) | ([W/V]*[H/U]*[D/O]*N) | (K) | +// | Shape_N | (K) | (C) | (C,S,R,T) | +// | Shape_K | (C,S,R,T) | (K,S,R,T) | (Q*P*Z*N) | +// | 
Shape_L | _1 | (V*U*O) | _1 | +template +CUTLASS_HOST_DEVICE +constexpr auto +get_linearized_problem_shape_MNKL(ConvProblemShape const& problem_shape) { + + auto [M, N, K, L] = get_transformed_problem_shape_MNKL(problem_shape); + + if constexpr (ConvOp == conv::Operator::kFprop || ConvOp == conv::Operator::kDgrad) { + return cute::make_shape(cute::product(M), N, K, cute::product(L)); + } + else if constexpr (ConvOp == conv::Operator::kWgrad) { + return cute::make_shape(M, N, cute::product(K), L); + } + +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass::conv::detail + +//////////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/conv/device/conv_universal_adapter.hpp b/include/cutlass/conv/device/conv_universal_adapter.hpp index 0472b898c2..193f8d8854 100644 --- a/include/cutlass/conv/device/conv_universal_adapter.hpp +++ b/include/cutlass/conv/device/conv_universal_adapter.hpp @@ -61,7 +61,7 @@ template class ConvUniversalAdapter { public: - using ConvKernel = ConvKernel_; + using ConvKernel = GetUnderlyingKernel_t; using TileShape = typename ConvKernel::TileShape; using ElementA = typename ConvKernel::ElementA; using ElementB = typename ConvKernel::ElementB; @@ -76,7 +76,7 @@ class ConvUniversalAdapter // Tease out meta-information about the conv algorithm static constexpr conv::Operator kConvolutionalOperator = DispatchPolicy::ConvOp; - static constexpr int NumSpatialDimensions = ConvKernel::NumSpatialDimensions; + static constexpr int NumSpatialDimensions = CollectiveMainloop::NumSpatialDimensions; // If our TiledMMA's instruction thread layout size is larger than 1, we know its a tensorop! using OperatorClass = cute::conditional_t< @@ -121,13 +121,13 @@ class ConvUniversalAdapter static int constexpr kStages = CollectiveMainloop::DispatchPolicy::Stages; // Inspect TiledCopy for A and B to compute the alignment size - static int constexpr kAlignmentA = detail::get_alignment_count_from_gmem_tiled_copy< + static int constexpr kAlignmentA = cutlass::detail::get_alignment_count_from_gmem_tiled_copy< typename CollectiveMainloop::GmemTiledCopyA, ElementA>(); - static int constexpr kAlignmentB = detail::get_alignment_count_from_gmem_tiled_copy< + static int constexpr kAlignmentB = cutlass::detail::get_alignment_count_from_gmem_tiled_copy< typename CollectiveMainloop::GmemTiledCopyB, ElementB>(); - static int constexpr kAlignmentC = detail::get_alignment_count_from_gmem_tiled_copy< + static int constexpr kAlignmentC = cutlass::detail::get_alignment_count_from_gmem_tiled_copy< typename CollectiveEpilogue::GmemTiledCopyC, ElementC>(); - static int constexpr kAlignmentD = detail::get_alignment_count_from_gmem_tiled_copy< + static int constexpr kAlignmentD = cutlass::detail::get_alignment_count_from_gmem_tiled_copy< typename CollectiveEpilogue::GmemTiledCopyD, ElementD>(); using EpilogueOutputOp = typename CollectiveEpilogue::ThreadEpilogueOp; @@ -297,8 +297,9 @@ class ConvUniversalAdapter Status launch_result; // Use extended launch API only for mainloops that use it if constexpr (ConvKernel::ArchTag::kMinComputeCapability >= 90) { - constexpr bool is_static_1x1x1 = cute::is_static_v and - cute::size(typename ConvKernel::DispatchPolicy::ClusterShape{}) == 1; + [[maybe_unused]] constexpr bool is_static_1x1x1 = + cute::is_static_v and + cute::size(typename ConvKernel::DispatchPolicy::ClusterShape{}) == 1; dim3 cluster(cute::size<0>(typename 
ConvKernel::DispatchPolicy::ClusterShape{}), cute::size<1>(typename ConvKernel::DispatchPolicy::ClusterShape{}), cute::size<2>(typename ConvKernel::DispatchPolicy::ClusterShape{})); diff --git a/include/cutlass/conv/device/direct_convolution.h b/include/cutlass/conv/device/direct_convolution.h index 84953d8036..43ab94b5fc 100644 --- a/include/cutlass/conv/device/direct_convolution.h +++ b/include/cutlass/conv/device/direct_convolution.h @@ -211,6 +211,7 @@ class DirectConvolution { dim3 grid = ReorderKernel::get_grid_shape(params_); dim3 block = ReorderKernel::get_block_shape(); + cutlass::arch::synclog_setup(); cutlass::Kernel<<>>(params_); } @@ -229,6 +230,7 @@ class DirectConvolution { if (status != cudaSuccess) return Status::kErrorInternal; + cutlass::arch::synclog_setup(); cutlass::Kernel<<>>(params_); cudaError_t result = cudaGetLastError(); diff --git a/include/cutlass/conv/device/implicit_gemm_convolution.h b/include/cutlass/conv/device/implicit_gemm_convolution.h index 62c7e8715d..a1cb06e98f 100644 --- a/include/cutlass/conv/device/implicit_gemm_convolution.h +++ b/include/cutlass/conv/device/implicit_gemm_convolution.h @@ -53,7 +53,7 @@ template class ImplicitGemmConvolution { public: - using UnderlyingKernel = ImplicitGemmKernel_; + using UnderlyingKernel = GetUnderlyingKernel_t; using ElementA = typename UnderlyingKernel::ElementA; using LayoutA = typename UnderlyingKernel::LayoutA; @@ -103,7 +103,6 @@ class ImplicitGemmConvolution { /// Determines whether the Implicit GEMM can execute the given problem. static Status can_implement(Arguments const &args) { - // dispatch to iterators Status status = UnderlyingKernel::Mma::IteratorA::can_implement(args.problem_size); if (Status::kSuccess != status) { @@ -164,9 +163,8 @@ class ImplicitGemmConvolution { // check for unsupported problem sizes for strided dgrad / deconv implementation if ((kConvolutionalOperator == conv::Operator::kDgrad || kConvolutionalOperator == conv::Operator::kDeconv) && kStrideSupport == conv::StrideSupport::kStrided) { - // split-k (serial or parallel) is not supported for strided dgrad / deconv - if(args.problem_size.split_k_slices > 1) { + if(args.problem_size.split_k_slices > 1 && (args.problem_size.stride().at(args.problem_size.stride().max_dim_index()) > 1)) { return Status::kErrorNotSupported; } @@ -291,7 +289,7 @@ class ImplicitGemmConvolution { } /// Runs the kernel using initialized state. - Status run(cudaStream_t stream = nullptr, CudaHostAdapter *cuda_adapter = nullptr) { + Status run(cudaStream_t stream = nullptr, CudaHostAdapter *cuda_adapter = nullptr, int32_t kernel_index = 0) { ThreadblockSwizzle threadblock_swizzle; @@ -311,7 +309,7 @@ class ImplicitGemmConvolution { void* kernel_params[] = {¶ms_}; launch_result = cuda_adapter->launch( - grid, dim3(1,1,1), block, smem_size, stream, kernel_params, 0 + grid, dim3(1,1,1), block, smem_size, stream, kernel_params, kernel_index ); } else { @@ -319,6 +317,7 @@ class ImplicitGemmConvolution { } } else { + cutlass::arch::synclog_setup(); cutlass::Kernel<<>>(params_); } @@ -333,20 +332,20 @@ class ImplicitGemmConvolution { } /// Runs the kernel using initialized state. - Status operator()(cudaStream_t stream = nullptr, CudaHostAdapter *cuda_adapter = nullptr) { - return run(stream, cuda_adapter); + Status operator()(cudaStream_t stream = nullptr, CudaHostAdapter *cuda_adapter = nullptr, int32_t kernel_index = 0) { + return run(stream, cuda_adapter, kernel_index); } /// Runs the kernel using initialized state. 
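// A usage sketch for the new trailing parameter: kernel_index is forwarded to
// cuda_adapter->launch() in run() above (previously a hard-coded 0), so callers that
// go through a CudaHostAdapter can select a specific registered kernel; the default
// of 0 keeps the previous behavior. Illustrative call (MyConvKernel is hypothetical):
//
//   ImplicitGemmConvolution<MyConvKernel> op;
//   op(args, workspace, stream, &cuda_adapter, /*kernel_index=*/1);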
Status operator()( Arguments const &args, void *workspace = nullptr, - cudaStream_t stream = nullptr, CudaHostAdapter *cuda_adapter = nullptr) { + cudaStream_t stream = nullptr, CudaHostAdapter *cuda_adapter = nullptr, int32_t kernel_index = 0) { Status status = initialize(args, workspace, stream, cuda_adapter); if (status == Status::kSuccess) { - status = run(stream, cuda_adapter); + status = run(stream, cuda_adapter, kernel_index); } return status; diff --git a/include/cutlass/conv/device/implicit_gemm_convolution_fusion.h b/include/cutlass/conv/device/implicit_gemm_convolution_fusion.h index 1eb0d5600e..265156cc5b 100644 --- a/include/cutlass/conv/device/implicit_gemm_convolution_fusion.h +++ b/include/cutlass/conv/device/implicit_gemm_convolution_fusion.h @@ -231,6 +231,7 @@ class ImplicitGemmConvolutionFusion { int smem_size = int(sizeof(typename ImplicitGemmFusionKernel::SharedStorage)); + cutlass::arch::synclog_setup(); cutlass::Kernel<<>>(params_); cudaError_t result = cudaGetLastError(); diff --git a/include/cutlass/conv/dispatch_policy.hpp b/include/cutlass/conv/dispatch_policy.hpp index 039f4539c4..b8b5eb2bff 100644 --- a/include/cutlass/conv/dispatch_policy.hpp +++ b/include/cutlass/conv/dispatch_policy.hpp @@ -37,6 +37,8 @@ #include "cute/layout.hpp" #include "cute/numeric/integral_constant.hpp" +#include "cutlass/gemm/dispatch_policy.hpp" + ////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////// @@ -48,7 +50,7 @@ namespace cutlass::conv { // // Policies for categorical dispatch of mainloop against kernel grid schedules // -struct KernelImplicitTmaWarpSpecializedSm90 { }; +struct KernelImplicitTmaWarpSpecializedSm90 : cutlass::gemm::KernelTmaWarpSpecialized { }; struct KernelImplicitTmaWarpSpecializedSm90Cooperative { }; struct KernelImplicitTmaWarpSpecializedSm90Pingpong { }; @@ -84,3 +86,5 @@ struct MainloopSm90TmaGmmaWarpSpecializedImplicitGemm { ////////////////////////////////////////////////////////////////////////////// } // namespace cutlass::conv + +////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/conv/kernel/conv_universal.hpp b/include/cutlass/conv/kernel/conv_universal.hpp index 9d98dc9d96..23ccea2f8f 100644 --- a/include/cutlass/conv/kernel/conv_universal.hpp +++ b/include/cutlass/conv/kernel/conv_universal.hpp @@ -30,6 +30,7 @@ **************************************************************************************************/ #pragma once +#include "cutlass/conv/convnd_problem_shape.hpp" #include "cutlass/detail/dependent_false.hpp" //////////////////////////////////////////////////////////////////////////////// @@ -43,6 +44,7 @@ namespace cutlass::conv::kernel { * a composition of a collective mainloop and a collective epilogue. 
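 *
 * A sketch of the intended instantiation after this change (the collective types stand
 * in for whatever the collective builders produce):
 *
 *   using ProblemShape = cutlass::conv::ConvProblemShape<cutlass::conv::Operator::kFprop, 3>;
 *   using ConvKernel   = cutlass::conv::kernel::ConvUniversal<
 *       ProblemShape, CollectiveMainloop, CollectiveEpilogue>;
 *
 * i.e. the problem shape is now the kernel's first template parameter rather than a
 * member of the mainloop's Arguments.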
**/ template < + class ProblemShape_, class CollectiveMainloop_, class CollectiveEpilogue_, class TileSchedulerTag_ = void, diff --git a/include/cutlass/conv/kernel/sm90_implicit_gemm_tma_warpspecialized.hpp b/include/cutlass/conv/kernel/sm90_implicit_gemm_tma_warpspecialized.hpp index 95780bf84e..657ac6b3ec 100644 --- a/include/cutlass/conv/kernel/sm90_implicit_gemm_tma_warpspecialized.hpp +++ b/include/cutlass/conv/kernel/sm90_implicit_gemm_tma_warpspecialized.hpp @@ -37,9 +37,12 @@ #include "cute/tensor.hpp" #include "cute/arch/cluster_sm90.hpp" +#include "cutlass/conv/detail.hpp" #include "cutlass/conv/convolution.h" #include "cutlass/conv/dispatch_policy.hpp" +#include "cutlass/gemm/dispatch_policy.hpp" #include "cutlass/pipeline/sm90_pipeline.hpp" +#include "cutlass/gemm/kernel/gemm_universal.hpp" #include "cutlass/gemm/kernel/tile_scheduler.hpp" /////////////////////////////////////////////////////////////////////////////// @@ -49,365 +52,25 @@ namespace cutlass::conv::kernel { /////////////////////////////////////////////////////////////////////////////// template < + class ProblemShape_, class CollectiveMainloop_, class CollectiveEpilogue_, - class TileSchedulerTag_ + class TileScheduler_ > class ConvUniversal< + ProblemShape_, CollectiveMainloop_, CollectiveEpilogue_, - TileSchedulerTag_, - cute::enable_if_t>> -{ -public: - // - // Type Aliases - // - - // Mainloop derived types - using CollectiveMainloop = CollectiveMainloop_; - using TileShape = typename CollectiveMainloop::TileShape; - using TiledMma = typename CollectiveMainloop::TiledMma; - using ArchTag = typename CollectiveMainloop::ArchTag; - using ElementA = typename CollectiveMainloop::ElementA; - using StrideA = typename CollectiveMainloop::StrideA; - using ElementB = typename CollectiveMainloop::ElementB; - using StrideB = typename CollectiveMainloop::StrideB; - using DispatchPolicy = typename CollectiveMainloop::DispatchPolicy; - using ElementAccumulator = typename CollectiveMainloop::ElementAccumulator; - using ClusterShape = typename DispatchPolicy::ClusterShape; - using MainloopArguments = typename CollectiveMainloop::Arguments; - using MainloopParams = typename CollectiveMainloop::Params; - static constexpr int NumSpatialDimensions = CollectiveMainloop::NumSpatialDimensions; - static_assert(ArchTag::kMinComputeCapability >= 90); - // Epilogue derived types - using CollectiveEpilogue = CollectiveEpilogue_; - using ElementC = typename CollectiveEpilogue::ElementC; - using StrideC = typename CollectiveEpilogue::StrideC; - using ElementD = typename CollectiveEpilogue::ElementD; - using StrideD = typename CollectiveEpilogue::StrideD; - using EpilogueArguments = typename CollectiveEpilogue::Arguments; - using EpilogueParams = typename CollectiveEpilogue::Params; - - using TileSchedulerTag = TileSchedulerTag_; - static_assert(cute::is_void_v, - "TMA warp-specialized kernel does not support specializing the tile scheduler."); - using TileScheduler = typename cutlass::gemm::kernel::detail::TileSchedulerSelector< - TileSchedulerTag, ArchTag, TileShape, ClusterShape>::Scheduler; - using TileSchedulerArguments = typename TileScheduler::Arguments; - - // Kernel level shared memory storage - struct SharedStorage { - union TensorStorage { - using MainloopTensorStorage = typename CollectiveMainloop::TensorStorage; - using EpilogueTensorStorage = typename CollectiveEpilogue::TensorStorage; - - MainloopTensorStorage mainloop; - EpilogueTensorStorage epilogue; - } tensors; - - struct PipelineStorage : cute::aligned_struct<16> { - 
using MainloopPipelineStorage = typename CollectiveMainloop::PipelineStorage; - using EpiLoadPipelineStorage = typename CollectiveEpilogue::PipelineStorage; - - alignas(16) MainloopPipelineStorage mainloop; - alignas(16) EpiLoadPipelineStorage epi_load; - } pipelines; - }; - - static constexpr int SharedStorageSize = sizeof(SharedStorage); - static constexpr uint32_t NumLoadWarpGroups = 1; - static constexpr uint32_t NumMmaWarpGroups = 1; - static constexpr uint32_t MaxThreadsPerBlock = CUTE_STATIC_V(size(TiledMma{})) + (NumLoadWarpGroups * NumThreadsPerWarpGroup); - static constexpr uint32_t MinBlocksPerMultiprocessor = 1; - - // Host facing host arguments - struct Arguments { - MainloopArguments mainloop{}; - EpilogueArguments epilogue{}; - KernelHardwareInfo hw_info{}; - TileSchedulerArguments scheduler{}; - }; - - // Kernel device entry point API - struct Params { - MainloopParams mainloop; - EpilogueParams epilogue; - }; - - // - // Methods - // - - // Map user facing arguments to device facing params - static Params - to_underlying_arguments(Arguments const& args, void* workspace) { - (void) workspace; - auto mainloop_params = CollectiveMainloop::to_underlying_arguments(args.mainloop, workspace); - auto problem_shape_MNKL = args.mainloop.problem_shape.get_transformed_problem_shape_MNKL(); - - return { - mainloop_params, - CollectiveEpilogue::to_underlying_arguments(problem_shape_MNKL, args.epilogue, workspace) - }; - } - - // Given arguemnts, returns true if the kernel can successfully compute upon them. False otherwise. - static bool - can_implement(Arguments const& args) { - bool implementable = true; - implementable &= CollectiveMainloop::can_implement(args.mainloop.problem_shape, args.mainloop); - implementable &= CollectiveEpilogue::can_implement(args.mainloop.problem_shape.get_transformed_problem_shape_MNKL(), args.epilogue); - return implementable; - } - - static size_t - get_workspace_size(Arguments const& args) { - return 0; - } - - static cutlass::Status - initialize_workspace(Arguments const& args, void* workspace = nullptr, cudaStream_t stream = nullptr, - CudaHostAdapter* cuda_adapter = nullptr) { - return Status::kSuccess; - } - - // Computes the kernel launch grid shape based on runtime parameters - static dim3 - get_grid_shape(Params const& params) { - return cutlass::gemm::kernel::detail::PersistentTileSchedulerSm90::get_tiled_cta_shape_mnl( - params.mainloop.problem_shape, TileShape{}, ClusterShape{}); - } - - static dim3 - get_block_shape() { - return dim3(MaxThreadsPerBlock, 1, 1); - } - - CUTLASS_DEVICE - void - operator()(Params const& params, char* smem_buf) { - using namespace cute; - using X = Underscore; - - // Any Tensor Op MMA Atom in the WGMMA ISA is arch conditional to sm90a. - #if ! defined(__CUDA_ARCH_FEAT_SM90_ALL) - if constexpr(size<0>(typename TiledMma::AtomShape_MNK{}) == 64) { - printf("ERROR : Arch conditional MMA instruction used without targeting sm90a compute capability. 
Aborting.\n"); - return; - } - #endif - - enum class WarpGroupRole { - Producer = 0, - Consumer = 1, - }; - - enum class ProducerWarpRole { - MainloopEpilogue = 0, - Warp1 = 1, - Warp2 = 2, - Warp3 = 3 - }; - - // Kernel level shared memory storage - SharedStorage& shared_storage = *reinterpret_cast(smem_buf); - - int thread_idx = int(threadIdx.x); - int lane_idx = canonical_lane_idx(); - int warp_idx = canonical_warp_idx_sync(); - int warp_idx_in_warp_group = warp_idx % NumWarpsPerWarpGroup; - int warp_group_thread_idx = thread_idx % NumThreadsPerWarpGroup; - auto warp_group_role = WarpGroupRole(canonical_warp_group_idx()); - auto producer_warp_role = ProducerWarpRole(warp_idx_in_warp_group); - int lane_predicate = cute::elect_one_sync(); - uint32_t block_rank_in_cluster = cute::block_rank_in_cluster(); - - // Issue Tma Descriptor Prefetch from a single thread - if ((warp_idx == 0) && lane_predicate) { - CollectiveMainloop::prefetch_tma_descriptors(params.mainloop); - CollectiveEpilogue::prefetch_tma_descriptors(params.epilogue); - } - - // Mainloop Load pipeline - using MainloopPipeline = typename CollectiveMainloop::MainloopPipeline; - typename MainloopPipeline::Params mainloop_pipeline_params; - if (warp_group_role == WarpGroupRole::Producer && producer_warp_role == ProducerWarpRole::MainloopEpilogue) { - mainloop_pipeline_params.role = MainloopPipeline::ThreadCategory::Producer; - } - if (warp_group_role == WarpGroupRole::Consumer) { - mainloop_pipeline_params.role = MainloopPipeline::ThreadCategory::Consumer; - } - mainloop_pipeline_params.is_leader = warp_group_thread_idx == 0; - mainloop_pipeline_params.num_consumers = NumThreadsPerWarpGroup; - mainloop_pipeline_params.transaction_bytes = params.mainloop.tma_transaction_bytes; - MainloopPipeline mainloop_pipeline(shared_storage.pipelines.mainloop, mainloop_pipeline_params, ClusterShape{}); - - // Epilogue Load pipeline - using EpiLoadPipeline = typename CollectiveEpilogue::LoadPipeline; - typename EpiLoadPipeline::Params epi_load_pipeline_params; - if (warp_group_role == WarpGroupRole::Producer && producer_warp_role == ProducerWarpRole::MainloopEpilogue) { - epi_load_pipeline_params.role = EpiLoadPipeline::ThreadCategory::Producer; - } - if (warp_group_role == WarpGroupRole::Consumer) { - epi_load_pipeline_params.role = EpiLoadPipeline::ThreadCategory::Consumer; - } - epi_load_pipeline_params.dst_blockid = cute::block_rank_in_cluster(); - epi_load_pipeline_params.producer_arv_count = NumThreadsPerWarp; - epi_load_pipeline_params.consumer_arv_count = NumThreadsPerWarpGroup; - if constexpr (CollectiveEpilogue::RequiresTransactionBytes) { - epi_load_pipeline_params.transaction_bytes = params.epilogue.tma_transaction_bytes; - } - EpiLoadPipeline epi_load_pipeline(shared_storage.pipelines.epi_load, epi_load_pipeline_params); - - // Epilogue Store pipeline - using EpiStorePipeline = typename CollectiveEpilogue::StorePipeline; - typename EpiStorePipeline::Params epi_store_pipeline_params; - epi_store_pipeline_params.always_wait = true; - EpiStorePipeline epi_store_pipeline(epi_store_pipeline_params); - - // Initialize starting pipeline states for the collectives - // Epilogue store pipe is producer-only (consumer is TMA unit, waits via scoreboarding) - typename CollectiveMainloop::PipelineState mainloop_pipe_consumer_state; - typename CollectiveEpilogue::LoadPipelineState epi_load_pipe_consumer_state; - - // For the DMA Load (producer) we start with an opposite phase - // i.e., we skip all waits since we know that the buffer is indeed 
empty - PipelineState mainloop_pipe_producer_state = cutlass::make_producer_start_state(); - PipelineState epi_load_pipe_producer_state = cutlass::make_producer_start_state(); - PipelineState epi_store_pipe_producer_state = cutlass::make_producer_start_state(); - - auto cluster_wait_fn = [&] () { - // We need this to guarantee that the Pipeline init is visible - // To all producers and consumer thread blocks in the Cluster - if constexpr (size(ClusterShape{}) > 1) { - cute::cluster_arrive_relaxed(); - return [] () { cute::cluster_wait(); }; - } - else { - __syncthreads(); - return [] () {}; // do nothing - } - } (); - - // Separate out problem shape for convenience - auto problem_shape_MNKL = append<4>(params.mainloop.problem_shape, _1{}); - auto [M, N, K, L] = problem_shape_MNKL; - - // TMA requires special handling of strides to deal with coord codomain mapping - // Represent the full tensors -- get these from TMA - Tensor mA_mk = params.mainloop.tma_load_a.get_tma_tensor(make_shape(M, K)); - Tensor mB_nk = params.mainloop.tma_load_b.get_tma_tensor(make_shape(N, K)); - - // Get the appropriate blocks for this thread block -- potential for thread block locality - auto cta_tile_shape = TileShape{}; // (BLK_M,BLK_N,BLK_K) - TiledMma tiled_mma; - - // Make tiled views, defer the slice - Tensor gA_mk = local_tile(mA_mk, cta_tile_shape, make_coord(_,_,_), Step<_1, X,_1>{}); // (BLK_M,BLK_K,m,k) - Tensor gB_nk = local_tile(mB_nk, cta_tile_shape, make_coord(_,_,_), Step< X,_1,_1>{}); // (BLK_N,BLK_K,n,k) - - // Compute m_coord, n_coord, and l_coord with their post-tiled shapes - auto m_coord = idx2crd(int(blockIdx.x), shape<2>(gA_mk)); - auto n_coord = idx2crd(int(blockIdx.y), shape<2>(gB_nk), compact_col_major(shape<2>(gB_nk))); - - // The output shape M is linearized so the output coord M here should also be linearized. 
- auto output_tile_coord = make_coord(int(blockIdx.x), n_coord, _, Int<0>{}); - - // Slice with m_coord and n_coord - Tensor gA = gA_mk(_,_,m_coord,_); // (BLK_M,BLK_K,k) - Tensor gB = gB_nk(_,_,n_coord,_); // (BLK_N,BLK_K,k) - - // Get pipeline iterators and increments from tensor shapes - auto k_tile_iter = cute::make_coord_iterator(shape<2>(gA)); - auto k_tile_count = size<2>(gA); - - // In a warp specialized kernel, collectives expose data movement and compute operations separately - CollectiveMainloop collective_mainloop; - CollectiveEpilogue collective_epilogue{params.epilogue, shared_storage.tensors.epilogue}; - - // Wait for all thread blocks in Cluster - cluster_wait_fn(); - - if (warp_group_role == WarpGroupRole::Producer) { - if (producer_warp_role == ProducerWarpRole::MainloopEpilogue) { - collective_mainloop.load( - mainloop_pipeline, - mainloop_pipe_producer_state, - gA, params.mainloop.tma_load_a, - gB, params.mainloop.tma_load_b, - k_tile_iter, k_tile_count, - lane_idx, - block_rank_in_cluster, - shared_storage.tensors.mainloop - ); - // Update starting mainloop pipeline state for the pipeline drain - mainloop_pipe_producer_state.advance(k_tile_count); - // Make sure mainloop consumer has been waited upon before issuing epilogue load - collective_mainloop.load_tail(mainloop_pipeline, mainloop_pipe_producer_state); - - if (collective_epilogue.is_producer_load_needed()) { - epi_load_pipe_producer_state = collective_epilogue.load( - epi_load_pipeline, - epi_load_pipe_producer_state, - problem_shape_MNKL, - cta_tile_shape, - output_tile_coord, - tiled_mma, - lane_idx, - shared_storage.tensors.epilogue - ); - collective_epilogue.load_tail(epi_load_pipeline, epi_load_pipe_producer_state); - } - } - } - else if (warp_group_role == WarpGroupRole::Consumer) { - Tensor accumulators = partition_fragment_C(tiled_mma, take<0,2>(cta_tile_shape)); // (MMA,MMA_M,MMA_N) - - collective_mainloop.mma( - mainloop_pipeline, - mainloop_pipe_consumer_state, - accumulators, - k_tile_count, - thread_idx, - shared_storage.tensors.mainloop, - params.mainloop - ); - - // Make sure the math instructions are done and free buffers before entering the epilogue - collective_mainloop.mma_tail( - mainloop_pipeline, - mainloop_pipe_consumer_state, - k_tile_count - ); - - // Epilogue and write to gD - auto [epi_load_pipe_consumer_state_next, epi_store_pipe_producer_state_next] = - collective_epilogue.store( - epi_load_pipeline, - epi_load_pipe_consumer_state, - epi_store_pipeline, - epi_store_pipe_producer_state, - problem_shape_MNKL, - cta_tile_shape, - output_tile_coord, - accumulators, - tiled_mma, - warp_group_thread_idx, - shared_storage.tensors.epilogue - ); - - collective_epilogue.store_tail( - epi_load_pipeline, - epi_load_pipe_consumer_state_next, - epi_store_pipeline, - epi_store_pipe_producer_state_next - ); - } - } -}; - + TileScheduler_, + cute::enable_if_t> +> : public cutlass::gemm::kernel::GemmUniversal< + ProblemShape_, + CollectiveMainloop_, + CollectiveEpilogue_, + TileScheduler_ +> +{}; /////////////////////////////////////////////////////////////////////////////// } // namespace cutlass::conv::kernel + diff --git a/include/cutlass/cuda_host_adapter.hpp b/include/cutlass/cuda_host_adapter.hpp index f9ff723ce1..1c8f56a652 100644 --- a/include/cutlass/cuda_host_adapter.hpp +++ b/include/cutlass/cuda_host_adapter.hpp @@ -82,6 +82,7 @@ namespace cutlass { ///////////////////////////////////////////////////////////////////////////////////////////////// + #if !defined(__CUDACC_RTC__) #include @@ 
-152,6 +153,7 @@ CUTLASS_CUDA_DRIVER_WRAPPER_DECL(cuTensorMapEncodeIm2col, 12000); #endif // !defined(__CUDACC_RTC__) + ///////////////////////////////////////////////////////////////////////////////////////////////// /// This class manages runtime CUlaunchAttribute that can be supplied to CudaHostAdapter diff --git a/include/cutlass/cutlass.h b/include/cutlass/cutlass.h index f396528307..e12616a201 100644 --- a/include/cutlass/cutlass.h +++ b/include/cutlass/cutlass.h @@ -35,6 +35,7 @@ #pragma once +#include "cutlass/arch/synclog.hpp" #include "cutlass/detail/helper_macros.hpp" //////////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/detail/collective.hpp b/include/cutlass/detail/collective.hpp index d3c4c04b74..a4b288e7c9 100644 --- a/include/cutlass/detail/collective.hpp +++ b/include/cutlass/detail/collective.hpp @@ -31,7 +31,6 @@ #pragma once #include "cute/container/tuple.hpp" - ///////////////////////////////////////////////////////////////////////////////////////////////// namespace cutlass::gemm::collective { diff --git a/include/cutlass/detail/layout.hpp b/include/cutlass/detail/layout.hpp index 429e5c2f06..216ba40285 100644 --- a/include/cutlass/detail/layout.hpp +++ b/include/cutlass/detail/layout.hpp @@ -30,13 +30,17 @@ **************************************************************************************************/ #pragma once +#include "cute/layout.hpp" +#include "cute/pointer_sparse.hpp" // cute::is_sparse +#include "cute/swizzle.hpp" // cute::Swizzle +#include "cute/swizzle_layout.hpp" // cute::detail::get_swizzle_portion +#include "cute/util/type_traits.hpp" +#include "cute/arch/copy_sm90_tma.hpp" #include "cutlass/layout/matrix.h" #include "cutlass/layout/tensor.h" #include "cutlass/numeric_types.h" +#include "cutlass/detail/collective.hpp" -#include "cute/layout.hpp" -#include "cute/util/type_traits.hpp" -#include "cute/arch/copy_sm90_tma.hpp" //////////////////////////////////////////////////////////////////////////////////////////////////// namespace cutlass::detail { @@ -199,9 +203,16 @@ template constexpr auto stride_to_layout_tag_A() { + using InternalStrideA = cute::remove_pointer_t; if constexpr (is_major<0, StrideA>()) { // M major return layout::ColumnMajor{}; } + // Specialize for sparse layout + else if constexpr (cute::get<0>(InternalStrideA{}) == cute::_2{} && + cute::rank(cute::get<1>(InternalStrideA{})) == 2 && + cute::is_same_v(InternalStrideA{}))>>) { + return layout::ColumnMajor{}; + } else { // K major return layout::RowMajor{}; } @@ -309,6 +320,10 @@ get_alignment_count_from_gmem_tiled_copy() { else { // For TMA tiled copies, we know the alignment has to be 128 bits if constexpr (is_tma_copy_engine()) { + // For sparse MMA, alignment in logical elements is increased by sparsity factor + if constexpr (cute::is_sparse_v) { + return 128 / sizeof_bits::value * ElementMma::sparsity; + } return 128 / sizeof_bits::value; } else { diff --git a/include/cutlass/detail/mma.hpp b/include/cutlass/detail/mma.hpp index 058f5fd3ea..0e491b9c40 100644 --- a/include/cutlass/detail/mma.hpp +++ b/include/cutlass/detail/mma.hpp @@ -42,6 +42,11 @@ namespace cutlass::detail { template struct IsSparseTensorOp : cute::false_type { }; +// TiledMma for sparse must have ValTypeE +template +struct IsSparseTensorOp> + : cute::true_type { }; + // The following metafunction is used to extract the OperatorClass from a cutlass 3.x kernel. 
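// On the sparse additions above (a sketch; member names follow the comments in this
// diff): a sparse TiledMma exposes a metadata value type ValTypeE, and the new
// partial specialization keys on its presence, so IsSparseTensorOp<TiledMma>::value
// is true exactly for such MMAs. The layout helper amended above applies the matching
// alignment rule: for a 16-bit sparse MMA element with sparsity == 2, the required
// contiguous extent grows from 128/16 = 8 to 8 * 2 = 16 logical elements.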
template struct get_operator_class { diff --git a/include/cutlass/device_kernel.h b/include/cutlass/device_kernel.h index ba875a757a..7af5d96cf6 100644 --- a/include/cutlass/device_kernel.h +++ b/include/cutlass/device_kernel.h @@ -56,6 +56,13 @@ namespace cutlass { +template struct Type2Type { using type=T; }; +// using the simple type to replace the complex type to reduce this symbol size +template struct GetUnderlyingKernel : public Type2Type {}; +template class Wrapper > struct GetUnderlyingKernel> : public Wrapper {}; +template using GetUnderlyingKernel_t = typename GetUnderlyingKernel::type; + + //////////////////////////////////////////////////////////////////////////////// /// Generic CUTLASS kernel template. @@ -71,6 +78,7 @@ void Kernel(typename Operator::Params params) { Operator op; op(params, *shared_storage); + cutlass::arch::synclog_print(); } @@ -85,6 +93,8 @@ void Kernel2(typename Operator::Params params) { reinterpret_cast(SharedStorageBase); Operator::invoke(params, *shared_storage); + cutlass::arch::synclog_print(); + } @@ -107,6 +117,8 @@ void device_kernel(CUTLASS_GRID_CONSTANT typename Operator::Params const params) extern __shared__ char smem[]; Operator op; op(params, smem); + cutlass::arch::synclog_print(); + } //////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/epilogue/collective/builders/sm90_builder.inl b/include/cutlass/epilogue/collective/builders/sm90_builder.inl index 90a600028c..759591b5dc 100644 --- a/include/cutlass/epilogue/collective/builders/sm90_builder.inl +++ b/include/cutlass/epilogue/collective/builders/sm90_builder.inl @@ -258,7 +258,7 @@ struct Sm90TmaBuilderImpl { using GmemStrideTypeC = cutlass::detail::TagToStrideC_t; using GmemStrideTypeD = cutlass::detail::TagToStrideC_t; - + using UnderlyingGmemStrideTypeC = cute::remove_pointer_t; using UnderlyingGmemStrideTypeD = cute::remove_pointer_t; @@ -273,6 +273,9 @@ struct Sm90TmaBuilderImpl { // Get the smallest tiled copy we can use to retile the accumulators using CopyAtomC = Copy_Atom; + // Get register to register tiled copy that happen before shared memory store. + // Apply void as no register transform op needed currently. + using CopyOpR2R = void; // TMA builder allows for passing callbacks directly, which is either a fusion::FusionCallbacks // instance or a direct visitor implementation, e.g. 
fusion::Sm90LinearCombination @@ -300,7 +303,8 @@ struct Sm90TmaBuilderImpl { CopyOpS2G, decltype(detail::sm90_get_epilogue_smem_swizzle_layout_atom()), decltype(detail::sm90_get_smem_store_op_for_accumulator()), - CopyAtomC + CopyAtomC, + CopyOpR2R >; }; @@ -386,6 +390,7 @@ struct AuxStoreDescriptor { // No-smem builder template < + class OpClass, class TileShape_MNK, class ClusterShape_MNK, class EpilogueTileType, @@ -402,7 +407,7 @@ template < > struct CollectiveBuilder< arch::Sm90, - arch::OpClassTensorOp, + OpClass, TileShape_MNK, ClusterShape_MNK, EpilogueTileType, @@ -452,6 +457,7 @@ struct CollectiveBuilder< // Tma warp-specialized builder template < + class OpClass, class TileShape_MNK, class ClusterShape_MNK, class EpilogueTileType, @@ -468,7 +474,7 @@ template < > struct CollectiveBuilder< arch::Sm90, - arch::OpClassTensorOp, + OpClass, TileShape_MNK, ClusterShape_MNK, EpilogueTileType, @@ -513,6 +519,7 @@ public: // Auto builder template < + class OpClass, class TileShape_MNK, class ClusterShape_MNK, class EpilogueTileType, @@ -528,7 +535,7 @@ template < > struct CollectiveBuilder< arch::Sm90, - arch::OpClassTensorOp, + OpClass, TileShape_MNK, ClusterShape_MNK, EpilogueTileType, @@ -552,7 +559,7 @@ private: using EpilogueSchedule = NoSmemWarpSpecialized; using _CollectiveBuilder = CollectiveBuilder< arch::Sm90, - arch::OpClassTensorOp, + OpClass, TileShape_MNK, ClusterShape_MNK, EpilogueTileType, @@ -574,6 +581,7 @@ public: // DEPRECATED Tma warp-specialized builder for elementwise fusion template < + class OpClass, class TileShape_MNK, class ClusterShape_MNK, class EpilogueTileType, @@ -591,7 +599,7 @@ template < struct [[deprecated("Use TmaWarpSpecialized with fusion::LinCombEltAct instead")]] CollectiveBuilder< arch::Sm90, - arch::OpClassTensorOp, + OpClass, TileShape_MNK, ClusterShape_MNK, EpilogueTileType, @@ -618,7 +626,7 @@ public: using CollectiveOp = typename CollectiveBuilder< arch::Sm90, - arch::OpClassTensorOp, + OpClass, TileShape_MNK, ClusterShape_MNK, EpilogueTileType, @@ -637,6 +645,7 @@ public: // DEPRECATED Tma warp-specialized builder for bias + elementwise fusion template < + class OpClass, class TileShape_MNK, class ClusterShape_MNK, class EpilogueTileType, @@ -654,7 +663,7 @@ template < struct [[deprecated("Use TmaWarpSpecialized with fusion::LinCombPerRowBiasEltAct or fusion::LinCombPerRowBiasEltActAux instead")]] CollectiveBuilder< arch::Sm90, - arch::OpClassTensorOp, + OpClass, TileShape_MNK, ClusterShape_MNK, EpilogueTileType, @@ -714,6 +723,9 @@ private: // Get the smallest tiled copy we can use to retile the accumulators using CopyAtomC = Copy_Atom; + // Get register to register tiled copy that happen before shared memory store. + // Apply void as no register transform op needed. + using CopyOpR2R = void; public: using CollectiveOp = cutlass::epilogue::collective::Sm90EpilogueTmaWarpSpecializedBiasElementwise< @@ -733,7 +745,8 @@ public: SM90_TMA_STORE, decltype(detail::sm90_get_epilogue_smem_swizzle_layout_atom()), decltype(detail::sm90_get_smem_store_op_for_accumulator()), - CopyAtomC + CopyAtomC, + CopyOpR2R >; }; @@ -741,6 +754,7 @@ public: // since swapping NNN kernels input matrix and transposing its output at the same time then // we can get TTN kernel. 
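// Spelled out (the standard transposition identity, nothing new): since D = A * B
// implies D^T = B^T * A^T, the all-column-major (NNN) kernel can serve the TTN layout
// combination by passing the operands swapped, exchanging M and N, and interpreting
// the written output as its transpose.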
template < + class OpClass, class TileShape_MNK, class ClusterShape_MNK, class EpilogueTileType, @@ -756,7 +770,7 @@ template < > struct CollectiveBuilder< arch::Sm90, - arch::OpClassTensorOp, + OpClass, TileShape_MNK, ClusterShape_MNK, EpilogueTileType, diff --git a/include/cutlass/epilogue/collective/collective_builder.hpp b/include/cutlass/epilogue/collective/collective_builder.hpp index a14696b2f8..d54cd0a8f7 100644 --- a/include/cutlass/epilogue/collective/collective_builder.hpp +++ b/include/cutlass/epilogue/collective/collective_builder.hpp @@ -30,6 +30,9 @@ **************************************************************************************************/ #pragma once +#include // cute::DefaultCopy +#include // cute::is_base_of_v + #include "cutlass/detail/dependent_false.hpp" #include "cutlass/epilogue/fusion/callbacks.hpp" @@ -100,7 +103,7 @@ struct CallbacksBuilder< TileShape_MNK, EpilogueTile_MN, ElementAccumulator, - cute::enable_if_t> + cute::enable_if_t> > { using Callbacks = FusionCallbacks; }; diff --git a/include/cutlass/epilogue/collective/collective_epilogue.hpp b/include/cutlass/epilogue/collective/collective_epilogue.hpp index f8179b0a0e..8fb1a9588b 100644 --- a/include/cutlass/epilogue/collective/collective_epilogue.hpp +++ b/include/cutlass/epilogue/collective/collective_epilogue.hpp @@ -53,11 +53,19 @@ class CollectiveEpilogue { ///////////////////////////////////////////////////////////////////////////////////////////////// #include "detail.hpp" + +// +// Gemm +// #include "default_epilogue.hpp" #include "default_epilogue_array.hpp" #include "epilogue_tensor_broadcast.hpp" #include "sm70_epilogue_vectorized.hpp" +#include "sm70_epilogue_vectorized_array.hpp" #include "sm90_epilogue_tma_warpspecialized.hpp" #include "sm90_epilogue_tma_warpspecialized_bias_elementwise.hpp" #include "sm90_epilogue_array_tma_warpspecialized.hpp" +// +// Conv +// ///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/epilogue/collective/detail.hpp b/include/cutlass/epilogue/collective/detail.hpp index b96b13fecc..a6e13bc7e9 100644 --- a/include/cutlass/epilogue/collective/detail.hpp +++ b/include/cutlass/epilogue/collective/detail.hpp @@ -199,6 +199,14 @@ struct IsThreadEpilogueOpWithActivation +struct IsThreadEpilogueOpWithElementwiseArguments : cute::false_type {}; + +template +struct IsThreadEpilogueOpWithElementwiseArguments< + ThreadEpilogueOp, + cute::void_t> : cute::true_type {}; + // Wrapper class to use operator-style epilogues in sm90 TMA warp-specialized kernels template class Sm90TmaWarpSpecializedAdapter : public EpilogueOp { @@ -430,7 +438,8 @@ class Sm90TmaWarpSpecializedAdapter : public EpilogueOp { // Dummy methods to perform different parts of TMA/Tensormap modifications - template + template CUTLASS_DEVICE void tensormaps_perform_update( diff --git a/include/cutlass/epilogue/collective/sm70_epilogue_vectorized.hpp b/include/cutlass/epilogue/collective/sm70_epilogue_vectorized.hpp index 69170f75ea..a8083dab1d 100644 --- a/include/cutlass/epilogue/collective/sm70_epilogue_vectorized.hpp +++ b/include/cutlass/epilogue/collective/sm70_epilogue_vectorized.hpp @@ -46,6 +46,25 @@ namespace collective { ///////////////////////////////////////////////////////////////////////////////////////////////// +template < + class StrideC, + class StrideD, + class ThreadEpilogueOp, + class SmemLayout, + class CopyAtomR2S, + class TiledCopyS2R, + class CopyAtomR2G, + class EpilogueScheduleType = 
EpilogueSimtVectorized, + class Enable = void +> +class Epilogue { + static_assert(cute::is_same_v || + cute::is_same_v, + "Could not find an epilogue specialization."); +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// +/// Epilogue Vectorized /// Applies an element wise operation to all elements within the fragment /// and writes it out to destination storage. /// @@ -61,9 +80,22 @@ template < class SmemLayout_, class CopyAtomR2S_, class TiledCopyS2R_, - class CopyAtomR2G_ + class CopyAtomR2G_, + class EpilogueScheduleType_ > -class Epilogue { +class Epilogue< + StrideC_, + StrideD_, + ThreadEpilogueOp_, + SmemLayout_, + CopyAtomR2S_, + TiledCopyS2R_, + CopyAtomR2G_, + EpilogueScheduleType_, + cute::enable_if_t< + cute::is_same_v + > + > { public: // // Type Aliases @@ -78,15 +110,17 @@ class Epilogue { using StrideC = StrideC_; using ElementD = typename ThreadEpilogueOp::ElementD; using StrideD = StrideD_; - + using ElementBias = typename detail::IsThreadEpilogueOpWithBias::type; using SmemLayout = SmemLayout_; using CopyAtomR2S = CopyAtomR2S_; using TiledCopyS2R = TiledCopyS2R_; using CopyAtomR2G = CopyAtomR2G_; - static const int kOutputAlignment = ThreadEpilogueOp::kCount; + using GmemTiledCopyC = void; + using GmemTiledCopyD = CopyAtomR2G; - using AlignmentType = typename cute::uint_bit::value * kOutputAlignment>::type; + static constexpr bool IsEpilogueBiasSupported = detail::IsThreadEpilogueOpWithBias::value; + using StrideBias = cute::conditional_t(), Stride<_1,_0,int64_t>, Stride<_0,_1,int64_t>>; static_assert(cute::rank(StrideC{}) == 3, "StrideCD must be rank-3: [M, N, L]"); static_assert(cute::rank(StrideD{}) == 3, "StrideCD must be rank-3: [M, N, L]"); @@ -96,9 +130,35 @@ class Epilogue { cute::array_aligned> smem_epilogue; }; + static constexpr bool IsActHasArgs = detail::IsThreadEpilogueOpWithElementwiseArguments::value; + // Host side epilogue arguments + template + struct ThreadEpilogueOpArguments { + ElementScalar alpha{0}; + ElementScalar beta{0}; + ElementScalar const* alpha_ptr = nullptr; + ElementScalar const* beta_ptr = nullptr; + ElementBias const* bias_ptr = nullptr; + StrideBias dBias{}; + }; + + template + struct ThreadEpilogueOpArguments< + ThreadEpiOp, + cute::enable_if_t::value>> { + ElementScalar alpha{0}; + ElementScalar beta{0}; + ElementScalar const* alpha_ptr = nullptr; + ElementScalar const* beta_ptr = nullptr; + ElementBias const* bias_ptr = nullptr; + StrideBias dBias{}; + typename ThreadEpiOp::ElementwiseArguments activation{}; + }; + struct Arguments { - typename ThreadEpilogueOp::Params thread{}; + ThreadEpilogueOpArguments thread{}; + using StrideBias = decltype(thread.dBias); ElementC const* ptr_C = nullptr; StrideC dC{}; ElementD* ptr_D = nullptr; @@ -106,7 +166,32 @@ class Epilogue { }; // Device side epilogue params - using Params = Arguments; + template + struct ParamsType { + typename ThreadEpiOp::Params thread{}; + ElementC const* ptr_C = nullptr; + StrideC dC{}; + ElementD* ptr_D = nullptr; + StrideD dD{}; + ElementBias const* ptr_Bias = nullptr; + StrideBias dBias{}; + }; + + template + struct ParamsType< + ThreadEpiOp, + cute::enable_if_t::value>> { + typename ThreadEpiOp::Params thread{}; + typename ThreadEpiOp::ElementwiseArguments activation{}; + ElementC const* ptr_C = nullptr; + StrideC dC{}; + ElementD* ptr_D = nullptr; + StrideD dD{}; + ElementBias const* ptr_Bias = nullptr; + StrideBias dBias{}; + }; + + using Params = ParamsType; // // Methods @@ -117,8 +202,36 @@ class 
Epilogue { to_underlying_arguments( [[maybe_unused]] ProblemShape const& _, Arguments const& args, - [[maybe_unused]] void* workspace) { - return args; + [[maybe_unused]] void* workspace) { + typename ThreadEpilogueOp::Params thread_op_args; + thread_op_args.alpha = args.thread.alpha; + thread_op_args.beta = args.thread.beta; + thread_op_args.alpha_ptr = args.thread.alpha_ptr; + thread_op_args.beta_ptr = args.thread.beta_ptr; + + if constexpr (IsActHasArgs) { + return { + thread_op_args, + args.thread.activation, + args.ptr_C, + args.dC, + args.ptr_D, + args.dD, + args.thread.bias_ptr, + args.thread.dBias + }; + } + else { + return { + thread_op_args, + args.ptr_C, + args.dC, + args.ptr_D, + args.dD, + args.thread.bias_ptr, + args.thread.dBias + }; + } } template @@ -169,8 +282,7 @@ class Epilogue { TiledMma tiled_mma, ResidueMNK residue_mnk, int thread_idx, - char* smem_buf) - { + char* smem_buf) { using namespace cute; using X = Underscore; @@ -192,88 +304,112 @@ class Epilogue { auto L = get<3>(problem_shape_mnkl); // Represent the full output tensor - Tensor mC_mnl = make_tensor(make_gmem_ptr(params.ptr_C), make_shape(M,N,L), params.dC); // (m,n,l) - Tensor mD_mnl = make_tensor(make_gmem_ptr(params.ptr_D), make_shape(M,N,L), params.dD); // (m,n,l) - Tensor gC_mnl = local_tile(mC_mnl, blk_shape_MNK, make_coord(_,_,_), Step<_1,_1, X>{}); // (BLK_M,BLK_N,m,n,l) - Tensor gD_mnl = local_tile(mD_mnl, blk_shape_MNK, make_coord(_,_,_), Step<_1,_1, X>{}); // (BLK_M,BLK_N,m,n,l) + Tensor mC_mnl = make_tensor(make_gmem_ptr(params.ptr_C), make_shape(M,N,L), params.dC); // (m,n,l) + Tensor mD_mnl = make_tensor(make_gmem_ptr(params.ptr_D), make_shape(M,N,L), params.dD); // (m,n,l) + Tensor mBias_mnl = make_tensor(make_gmem_ptr(params.ptr_Bias), make_shape(M,N,L), params.dBias); // (m,n,l) + + Tensor gC_mnl = local_tile(mC_mnl, blk_shape_MNK, make_coord(_,_,_), Step<_1,_1, X>{}); // (BLK_M,BLK_N,m,n,l) + Tensor gD_mnl = local_tile(mD_mnl, blk_shape_MNK, make_coord(_,_,_), Step<_1,_1, X>{}); // (BLK_M,BLK_N,m,n,l) + Tensor gBias_mnl = local_tile(mBias_mnl, blk_shape_MNK, make_coord(_,_,_), Step<_1,_1, X>{}); // (BLK_M,BLK_N,m,n,l) // Slice to get the tile this CTA is responsible for auto [m_coord, n_coord, k_coord, l_coord] = blk_coord_mnkl; Tensor gC = gC_mnl(_,_,m_coord,n_coord,l_coord); // (BLK_M,BLK_N) Tensor gD = gD_mnl(_,_,m_coord,n_coord,l_coord); // (BLK_M,BLK_N) - + Tensor gBias = gBias_mnl(_,_,m_coord,n_coord,l_coord); // (BLK_M,BLK_N) + // Construct a tensor in SMEM that we can partition for rearranging data SharedStorage& storage = *reinterpret_cast(smem_buf); - Tensor sC = make_tensor(make_smem_ptr(storage.smem_epilogue.data()), SmemLayout{}); // (SMEM_M,SMEM_N) + Tensor sAcc = make_tensor(make_smem_ptr(storage.smem_epilogue.data()), SmemLayout{}); // (SMEM_M,SMEM_N) - // Partition sC to match the accumulator partitioning + // Partition sAcc to match the accumulator partitioning auto tiled_r2s = make_tiled_copy_C(CopyAtomR2S{}, tiled_mma); - auto tC = tiled_r2s.get_thread_slice(thread_idx); - Tensor tCaC = tC.retile_S(accumulators); // ((Atom,AtomNum), MMA_M, MMA_N) - Tensor tCsC = tC.partition_D(sC); // ((Atom,AtomNum),PIPE_M,PIPE_N) + auto thread_r2s = tiled_r2s.get_thread_slice(thread_idx); + Tensor tRS_rAcc = thread_r2s.retile_S(accumulators); // ((Atom,AtomNum), MMA_M, MMA_N) + Tensor tRS_sAcc = thread_r2s.partition_D(sAcc); // ((Atom,AtomNum),PIPE_M,PIPE_N) // Tile gD and gC by the shape of SmemLayout first - auto tile = make_shape(size<0>(sC), size<1>(sC)); + auto tile = 
make_shape(size<0>(sAcc), size<1>(sAcc)); Tensor gCt = flat_divide(gC, tile); // (SMEM_M,SMEM_N,TILE_M,TILE_N) Tensor gDt = flat_divide(gD, tile); // (SMEM_M,SMEM_N,TILE_M,TILE_N) + Tensor gBiast = flat_divide(gBias, tile); // (SMEM_M,SMEM_N,TILE_M,TILE_N) - // Partition sC, gC, and gD for the output + // Partition sAcc, gC, and gD for the output auto tiled_s2r = TiledCopyS2R{}; - auto tD = tiled_s2r.get_thread_slice(thread_idx); - Tensor tDsC = tD.partition_S(sC); // ((Atom,AtomNum),ATOM_M,ATOM_N) - Tensor tDgC = tD.partition_D(gCt); // ((Atom,AtomNum),ATOM_M,ATOM_N,TILE_M,TILE_N) - Tensor tDgD = tD.partition_D(gDt); // ((Atom,AtomNum),ATOM_M,ATOM_N,TILE_M,TILE_N) + auto thread_s2r = tiled_s2r.get_thread_slice(thread_idx); + Tensor tSR_sAcc = thread_s2r.partition_S(sAcc); // ((Atom,AtomNum),ATOM_M,ATOM_N) + Tensor tSR_gC = thread_s2r.partition_D(gCt); // ((Atom,AtomNum),ATOM_M,ATOM_N,TILE_M,TILE_N) + Tensor tSR_gD = thread_s2r.partition_D(gDt); // ((Atom,AtomNum),ATOM_M,ATOM_N,TILE_M,TILE_N) + Tensor tSR_gBias = thread_s2r.partition_D(gBiast); // ((Atom,AtomNum),ATOM_M,ATOM_N,TILE_M,TILE_N) // Allocate intermediate registers on the dst tensors - Tensor tDrC = make_tensor(take<0,3>(shape(tDgC))); // ((Atom,AtomNum),ATOM_M,ATOM_N) - Tensor tDrD = make_tensor(shape(tDrC)); // ((Atom,AtomNum),ATOM_M,ATOM_N) + Tensor tSR_rAcc = make_tensor(take<0,3>(shape(tSR_gC))); // ((Atom,AtomNum),ATOM_M,ATOM_N) + Tensor tSR_rC = make_tensor(shape(tSR_rAcc)); // ((Atom,AtomNum),ATOM_M,ATOM_N) + Tensor tSR_rD = make_tensor(shape(tSR_rAcc)); // ((Atom,AtomNum),ATOM_M,ATOM_N) + Tensor tSR_rBias = make_tensor_like(tSR_gBias); // ((Atom,AtomNum),ATOM_M,ATOM_N,TILE_M,TILE_N) // Repeat the D-partitioning for coordinates and predication - Tensor cD = make_identity_tensor(make_shape(size<0>(gD),size<1>(gD))); // (BLK_M,BLK_N) -> (blk_m,blk_n) - Tensor cDt = flat_divide(cD, tile); // (SMEM_M,SMEM_N,TILE_M,TILE_N) - Tensor tDcD = tD.partition_D(cDt); // ((Atom,AtomNum),ATOM_M,ATOM_N,TILE_M,TILE_N) + Tensor cD = make_identity_tensor(make_shape(size<0>(gD),size<1>(gD))); // (BLK_M,BLK_N) -> (blk_m,blk_n) + Tensor cDt = flat_divide(cD, tile); // (SMEM_M,SMEM_N,TILE_M,TILE_N) + Tensor tSR_cD = thread_s2r.partition_D(cDt); // ((Atom,AtomNum),ATOM_M,ATOM_N,TILE_M,TILE_N) - CUTE_STATIC_ASSERT(size<1>(tCaC) % size<3>(tDgC) == 0); // TILE_M divides MMA_M - CUTE_STATIC_ASSERT(size<2>(tCaC) % size<4>(tDgC) == 0); // TILE_N divides MMA_N - CUTE_STATIC_ASSERT(typename TiledCopyS2R::TiledNumThr{} == size<0>(typename TiledMma::AtomLayoutC_TV{})); + CUTE_STATIC_ASSERT(size<1>(tRS_rAcc) % size<3>(tSR_gC) == 0); // TILE_M divides MMA_M + CUTE_STATIC_ASSERT(size<2>(tRS_rAcc) % size<4>(tSR_gC) == 0); // TILE_N divides MMA_N #if 0 if (thread_idx == 0 && m_coord == 0 && n_coord == 0) { print("aC : "); print(accumulators.layout()); print("\n"); print("gC : "); print(gC.layout()); print("\n"); print("gD : "); print(gD.layout()); print("\n"); - print("sC : "); print(sC.layout()); print("\n"); + print("gBias : "); print(gBias.layout()); print("\n"); + print("sAcc : "); print(sAcc.layout()); print("\n"); print("\n"); - print("tCsC : "); print(tCsC.layout()); print("\n"); - print("tCaC : "); print(tCaC.layout()); print("\n"); + print("tRS_sAcc : "); print(tRS_sAcc.layout()); print("\n"); + print("tRS_rAcc : "); print(tRS_rAcc.layout()); print("\n"); print("\n"); print("gDt : "); print(gDt.layout()); print("\n"); - print("tDsC : "); print(tDsC.layout()); print("\n"); - print("tDrC : "); print(tDrC.layout()); print("\n"); + print("tSR_sAcc : "); 
print(tSR_sAcc.layout()); print("\n"); + print("tSR_rAcc : "); print(tSR_rAcc.layout()); print("\n"); print("\n"); - print("tDrD : "); print(tDrD.layout()); print("\n"); - print("tDgC : "); print(tDgC.layout()); print("\n"); - print("tDgD : "); print(tDgD.layout()); print("\n"); + print("tSR_rC : "); print(tSR_rC.layout()); print("\n"); + print("tSR_rD : "); print(tSR_rD.layout()); print("\n"); + print("tSR_gC : "); print(tSR_gC.layout()); print("\n"); + print("tSR_gD : "); print(tSR_gD.layout()); print("\n"); print("\n"); + print("gBiast : "); print(gBiast.layout()); print("\n"); + print("tSR_gBias : "); print(tSR_gBias.layout()); print("\n"); + print("tSR_rBias : "); print(tSR_rBias.layout()); print("\n"); } #endif + if constexpr (IsEpilogueBiasSupported) { + if (params.ptr_Bias) { + // Filter so we don't issue redundant copies over stride-0 modes + // (only works if 0-strides are in same location, which is by construction) + Tensor tSR_gBias_flt = filter_zeros(tSR_gBias); + Tensor tSR_rBias_flt = filter_zeros(tSR_rBias); + Tensor tSR_cD_flt = filter_zeros(tSR_cD, tSR_gBias.stride()); + + // Step 0. Copy Bias from GMEM to fragment + auto pred_fn = [&] (auto const&... coords) { return elem_less(tSR_cD_flt(coords...), take<0, 2>(residue_mnk)); }; + copy_if(pred_fn, tSR_gBias_flt, tSR_rBias_flt); + } + } + // For each tiling needed for SmemLayout to cover shape(gD) CUTLASS_PRAGMA_UNROLL - for (int step_m = 0; step_m < size<2>(cDt); ++step_m) - { + for (int step_m = 0; step_m < size<2>(cDt); ++step_m) { CUTLASS_PRAGMA_UNROLL - for (int step_n = 0; step_n < size<3>(cDt); ++step_n) - { + for (int step_n = 0; step_n < size<3>(cDt); ++step_n) { // Step 1. Copy to SMEM CUTLASS_PRAGMA_UNROLL - for (int pipe_m = 0; pipe_m < size<1>(tCsC); ++pipe_m) { + for (int pipe_m = 0; pipe_m < size<1>(tRS_sAcc); ++pipe_m) { CUTLASS_PRAGMA_UNROLL - for (int pipe_n = 0; pipe_n < size<2>(tCsC); ++pipe_n) { - int mma_m = step_m * size<1>(tCsC) + pipe_m; - int mma_n = step_n * size<2>(tCsC) + pipe_n; + for (int pipe_n = 0; pipe_n < size<2>(tRS_sAcc); ++pipe_n) { + int mma_m = step_m * size<1>(tRS_sAcc) + pipe_m; + int mma_n = step_n * size<2>(tRS_sAcc) + pipe_n; - copy(tiled_r2s, tCaC(_,mma_m,mma_n), tCsC(_,pipe_m,pipe_n)); + copy(tiled_r2s, tRS_rAcc(_,mma_m,mma_n), tRS_sAcc(_,pipe_m,pipe_n)); } } @@ -281,59 +417,115 @@ class Epilogue { synchronize(); // Step 3. Copy from SMEM into a fragment - copy(tiled_s2r, tDsC, tDrC); + copy(tiled_s2r, tSR_sAcc, tSR_rAcc); // Step 4. Wait for SMEM reads to complete synchronize(); - Tensor tDgDmn = tDgD(_,_,_,step_m,step_n); - Tensor tDcDmn = tDcD(_,_,_,step_m,step_n); + Tensor tSR_gDmn = tSR_gD(_,_,_,step_m,step_n); + Tensor tSR_cDmn = tSR_cD(_,_,_,step_m,step_n); + + if constexpr (IsEpilogueBiasSupported) { + Tensor tSR_rBiasmn = tSR_rBias(_,_,_,step_m,step_n); + + if (epilogue_op.is_source_needed()) { + // source is needed + Tensor tSR_gCmn = tSR_gC(_,_,_,step_m,step_n); + + // Step 5. Copy C from GMEM to a fragment + CUTLASS_PRAGMA_UNROLL + for (int m = 0; m < size<1>(tSR_gDmn); ++m) { + CUTLASS_PRAGMA_UNROLL + for (int n = 0; n < size<2>(tSR_gDmn); ++n) { + // Predication + if (elem_less(tSR_cDmn(0,m,n), take<0,2>(residue_mnk))) { + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size<0>(tSR_rAcc); ++i) { + tSR_rC(i,m,n) = tSR_gCmn(i,m,n); + } + } + } + } + + // Step 6. 
Elementwise operation with conversion + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(tSR_rAcc); ++i) { + if constexpr (IsActHasArgs) { + epilogue_op(tSR_rD(i), tSR_rD(i), tSR_rAcc(i), tSR_rC(i), tSR_rBiasmn(i), params.activation); + } else { + epilogue_op(tSR_rD(i), tSR_rD(i), tSR_rAcc(i), tSR_rC(i), tSR_rBiasmn(i)); + } + } + } + else { + // source is not needed, avoid load and lift compute + + // Step 5. Elementwise operation with conversion + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(tSR_rAcc); ++i) { + if constexpr (IsActHasArgs) { + epilogue_op(tSR_rD(i), tSR_rD(i), tSR_rAcc(i), tSR_rBiasmn(i), params.activation); + } else { + epilogue_op(tSR_rD(i), tSR_rD(i), tSR_rAcc(i), tSR_rBiasmn(i)); + } + } + } - if (epilogue_op.is_source_needed()) { - // source is needed - Tensor tDgCmn = tDgC(_,_,_,step_m,step_n); CUTLASS_PRAGMA_UNROLL - for (int m = 0; m < size<1>(tDgDmn); ++m) - { + for (int m = 0; m < size<1>(tSR_gDmn); ++m) { CUTLASS_PRAGMA_UNROLL - for (int n = 0; n < size<2>(tDgDmn); ++n) - { + for (int n = 0; n < size<2>(tSR_gDmn); ++n) { // Predication - if (get<0>(tDcDmn(0,m,n)) < get<0>(residue_mnk) && - get<1>(tDcDmn(0,m,n)) < get<1>(residue_mnk)) - { - // Step 5. Elementwise operation with conversion - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < size<0>(tDrC); ++i) { - tDrD(i,m,n) = epilogue_op(tDrC(i,m,n), tDgCmn(i,m,n)); + if (elem_less(tSR_cDmn(0,m,n), take<0,2>(residue_mnk))) { + // The Last Step. Copy to GMEM + copy(CopyAtomR2G{}, tSR_rD(_,m,n), tSR_gDmn(_,m,n)); + } + } + } + } else { + if (epilogue_op.is_source_needed()) { + // source is needed + Tensor tSR_gCmn = tSR_gC(_,_,_,step_m,step_n); + + // Step 5. Copy C from GMEM to a fragment + CUTLASS_PRAGMA_UNROLL + for (int m = 0; m < size<1>(tSR_gDmn); ++m) { + CUTLASS_PRAGMA_UNROLL + for (int n = 0; n < size<2>(tSR_gDmn); ++n) { + // Predication + if (elem_less(tSR_cDmn(0,m,n), take<0,2>(residue_mnk))) { + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size<0>(tSR_rAcc); ++i) { + tSR_rC(i,m,n) = tSR_gCmn(i,m,n); + } } - // Step 6. Copy to GMEM - copy(CopyAtomR2G{}, tDrD(_,m,n), tDgDmn(_,m,n)); } } + + // Step 6. Elementwise operation with conversion + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(tSR_rAcc); ++i) { + tSR_rD(i) = epilogue_op(tSR_rAcc(i), tSR_rC(i)); + } } - } - else { - // source is not needed, avoid load and lift compute + else { + // source is not needed, avoid load and lift compute - // Step 5. Elementwise operation with conversion - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < size(tDrC); ++i) { - tDrD(i) = epilogue_op(tDrC(i)); + // Step 5. Elementwise operation with conversion + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(tSR_rAcc); ++i) { + tSR_rD(i) = epilogue_op(tSR_rAcc(i)); + } } CUTLASS_PRAGMA_UNROLL - for (int m = 0; m < size<1>(tDgDmn); ++m) - { + for (int m = 0; m < size<1>(tSR_gDmn); ++m) { CUTLASS_PRAGMA_UNROLL - for (int n = 0; n < size<2>(tDgDmn); ++n) - { + for (int n = 0; n < size<2>(tSR_gDmn); ++n) { // Predication - if (get<0>(tDcDmn(0,m,n)) < get<0>(residue_mnk) && - get<1>(tDcDmn(0,m,n)) < get<1>(residue_mnk)) - { - // Step 6. Copy to GMEM - copy(CopyAtomR2G{}, tDrD(_,m,n), tDgDmn(_,m,n)); + if (elem_less(tSR_cDmn(0,m,n), take<0,2>(residue_mnk))) { + // The Last Step. 
Copy to GMEM + copy(CopyAtomR2G{}, tSR_rD(_,m,n), tSR_gDmn(_,m,n)); } } } diff --git a/include/cutlass/epilogue/collective/sm70_epilogue_vectorized_array.hpp b/include/cutlass/epilogue/collective/sm70_epilogue_vectorized_array.hpp new file mode 100644 index 0000000000..8a70370b21 --- /dev/null +++ b/include/cutlass/epilogue/collective/sm70_epilogue_vectorized_array.hpp @@ -0,0 +1,412 @@ +/*************************************************************************************************** + * Copyright (c) 2024 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Functor performing elementwise operations used by epilogues. +*/ + +#pragma once + +#include "cutlass/epilogue/collective/sm70_epilogue_vectorized.hpp" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace epilogue { +namespace collective { + +///////////////////////////////////////////////////////////////////////////////////////////////// +/// Ptr Array Epilogue Vectorized +/// Applies an element wise operation to all elements within the fragment +/// and writes it out to destination storage. 
+/// +/// Ways to generalize this: +/// - CTA tile shape +/// - vectorization requirements (GMEM) +/// - vectoriz(able) transform() +/// +template < + class StrideC_, + class StrideD_, + class ThreadEpilogueOp_, + class SmemLayout_, + class CopyAtomR2S_, + class TiledCopyS2R_, + class CopyAtomR2G_, + class EpilogueScheduleType_ +> +class Epilogue< + StrideC_, + StrideD_, + ThreadEpilogueOp_, + SmemLayout_, + CopyAtomR2S_, + TiledCopyS2R_, + CopyAtomR2G_, + EpilogueScheduleType_, + cute::enable_if_t< + cute::is_same_v + > + > { +public: + // + // Type Aliases + // + // derived types of output thread level operator + using ThreadEpilogueOp = ThreadEpilogueOp_; + using ElementAccumulator = typename ThreadEpilogueOp::ElementAccumulator; + using ElementCompute = typename ThreadEpilogueOp::ElementCompute; + using ElementScalar = ElementCompute; + using ElementOutput = typename ThreadEpilogueOp::ElementOutput; + using ElementC = typename ThreadEpilogueOp::ElementC; + using StrideC = StrideC_; + using InternalStrideC = cute::remove_pointer_t; + using ElementD = typename ThreadEpilogueOp::ElementD; + using StrideD = StrideD_; + using InternalStrideD = cute::remove_pointer_t; + + using SmemLayout = SmemLayout_; + using CopyAtomR2S = CopyAtomR2S_; + using TiledCopyS2R = TiledCopyS2R_; + using CopyAtomR2G = CopyAtomR2G_; + + using GmemTiledCopyC = TiledCopyS2R; + using GmemTiledCopyD = TiledCopyS2R; + + static const int kOutputAlignment = ThreadEpilogueOp::kCount; + + using AlignmentType = typename cute::uint_bit::value * kOutputAlignment>::type; + + static_assert(cute::rank(InternalStrideC{}) == 3, "StrideCD must be rank-3: [M, N, L]"); + static_assert(cute::rank(InternalStrideD{}) == 3, "StrideCD must be rank-3: [M, N, L]"); + + struct SharedStorage + { + cute::array_aligned> smem_epilogue; + }; + + using TensorMapStorage = SharedStorage; + + // Host side epilogue arguments + struct Arguments { + typename ThreadEpilogueOp::Params thread{}; + ElementC const** ptr_C = nullptr; + StrideC dC{}; + ElementD** ptr_D = nullptr; + StrideD dD{}; + }; + + // Device side epilogue params + using Params = Arguments; + + // + // Methods + // + + template + static constexpr Params + to_underlying_arguments( + ProblemShape const&, + Arguments const& args, + [[maybe_unused]] void* workspace) { + return args; + } + + template + static size_t + get_workspace_size(ProblemShape const& problem_shape, Arguments const& args, int sm_count) { + return 0; + } + + template + static cutlass::Status + initialize_workspace(ProblemShape const& problem_shape, Arguments const& args, void* workspace, cudaStream_t stream, + CudaHostAdapter* cuda_adapter = nullptr) { + return cutlass::Status::kSuccess; + } + + template + static bool + can_implement( + [[maybe_unused]] ProblemShape const& problem_shape, + [[maybe_unused]] Arguments const& args) { + return true; + } + + CUTLASS_HOST_DEVICE + Epilogue(Params const& params_) + : params(params_) { } + + CUTLASS_DEVICE + bool + is_source_needed() { + // For Ptr-Array or Grouped Gemm we cannot determine if source is needed based on first beta. 
+ return true; + } + + template< + class ProblemShapeMNKL, + class BlockShapeMNK, + class BlockCoordMNKL, + class FrgEngine, class FrgLayout, + class TiledMma, + class ResidueMNK + > + CUTLASS_DEVICE void + operator()( + ProblemShapeMNKL problem_shape_mnkl, + BlockShapeMNK blk_shape_MNK, + BlockCoordMNKL blk_coord_mnkl, + cute::Tensor const& accumulators, // (MMA,MMA_M,MMA_N) + TiledMma tiled_mma, + ResidueMNK residue_mnk, + int thread_idx, + char* smem_buf) { + using namespace cute; + using X = Underscore; + + static_assert(cute::rank(ProblemShapeMNKL{}) == 4, "ProblemShapeMNKL must be rank 4"); + static_assert(is_static::value, "ThreadBlock tile shape must be static"); + static_assert(cute::rank(BlockShapeMNK{}) == 3, "BlockShapeMNK must be rank 3"); + static_assert(cute::rank(BlockCoordMNKL{}) == 4, "BlockCoordMNKL must be rank 3"); + + // synchronizing function for smem reads/writes +#if CUDA_BARRIER_ENABLED + auto synchronize = [] () { cutlass::arch::NamedBarrier::sync(typename TiledCopyS2R::TiledNumThr{}, cutlass::arch::ReservedNamedBarriers::EpilogueBarrier); }; +#else + auto synchronize = [] () { __syncthreads(); }; +#endif + + // Separate out problem shape for convenience + auto M = get<0>(problem_shape_mnkl); + auto N = get<1>(problem_shape_mnkl); + auto L = get<3>(problem_shape_mnkl); + // Batches are managed by using appropriate pointers to C and D matrices + const int32_t mock_L = 1; + const int32_t mock_l_coord = 0; + // Slice to get the tile this CTA is responsible for + auto [m_coord, n_coord, k_coord, l_coord] = blk_coord_mnkl; + + // If scalar alpha/beta are provided, i.e., same alpha/beta applies to all batches/groups. + // If pointers to alpha/beta are provided, i.e., alpha/beta can differ between batches/groups, + // we get the correct alpha/beta values for the current batch/group using group index. 
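    // Illustrative sketch (assumption, not code from this patch): a thread-level
    // epilogue constructed with a group index could resolve its scalars roughly as
    // follows, letting per-group alpha/beta pointer arrays override the single
    // host-provided values. The member names here are hypothetical placeholders.
    //
    //   ElementScalar alpha = (alpha_ptr_array != nullptr) ? *alpha_ptr_array[group_idx]
    //                       : (alpha_ptr != nullptr)       ? *alpha_ptr
    //                                                      : alpha_scalar;
    //   ElementScalar beta  = (beta_ptr_array  != nullptr) ? *beta_ptr_array[group_idx]
    //                       : (beta_ptr != nullptr)        ? *beta_ptr
    //                                                      : beta_scalar;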
+ ThreadEpilogueOp epilogue_op = ThreadEpilogueOp(params.thread, l_coord); + + if (epilogue_op.is_source_needed() && params.dC == nullptr) { + // Beta value is non-zero while pointer to C is a nullptr + assert(0); + } + + InternalStrideC stride_c; + InternalStrideD stride_d; + if constexpr (!cute::is_same_v) { + // If grouped gemm + if (epilogue_op.is_source_needed()) { + stride_c = params.dC[l_coord]; + } + stride_d = params.dD[l_coord]; + } + else { + stride_c = params.dC; + stride_d = params.dD; + } + + // Represent the full output tensor + ElementC const* ptr_C_l = nullptr; + if (epilogue_op.is_source_needed()) { + ptr_C_l = params.ptr_C[l_coord]; + } + Tensor mC_mnl = make_tensor(make_gmem_ptr(ptr_C_l), make_shape(M,N,mock_L), stride_c); // (m,n,l) + Tensor mD_mnl = make_tensor(make_gmem_ptr(params.ptr_D[l_coord]), make_shape(M,N,mock_L), stride_d); // (m,n,l) + Tensor gC_mnl = local_tile(mC_mnl, blk_shape_MNK, make_coord(_,_,_), Step<_1,_1, X>{}); // (BLK_M,BLK_N,m,n,l) + Tensor gD_mnl = local_tile(mD_mnl, blk_shape_MNK, make_coord(_,_,_), Step<_1,_1, X>{}); // (BLK_M,BLK_N,m,n,l) + + Tensor gC = gC_mnl(_,_,m_coord,n_coord,mock_l_coord); // (BLK_M,BLK_N) + Tensor gD = gD_mnl(_,_,m_coord,n_coord,mock_l_coord); // (BLK_M,BLK_N) + + // Construct a tensor in SMEM that we can partition for rearranging data + SharedStorage& storage = *reinterpret_cast(smem_buf); + Tensor sAcc = make_tensor(make_smem_ptr(storage.smem_epilogue.data()), SmemLayout{}); // (SMEM_M,SMEM_N) + + // Partition sAcc to match the accumulator partitioning + auto tiled_r2s = make_tiled_copy_C(CopyAtomR2S{}, tiled_mma); + auto thread_r2s = tiled_r2s.get_thread_slice(thread_idx); + Tensor tRS_rAcc = thread_r2s.retile_S(accumulators); // ((Atom,AtomNum), MMA_M, MMA_N) + Tensor tRS_sAcc = thread_r2s.partition_D(sAcc); // ((Atom,AtomNum),PIPE_M,PIPE_N) + + // Tile gD and gC by the shape of SmemLayout first + auto tile = make_shape(size<0>(sAcc), size<1>(sAcc)); + Tensor gCt = flat_divide(gC, tile); // (SMEM_M,SMEM_N,TILE_M,TILE_N) + Tensor gDt = flat_divide(gD, tile); // (SMEM_M,SMEM_N,TILE_M,TILE_N) + + // Partition sAcc, gC, and gD for the output + auto tiled_s2r = TiledCopyS2R{}; + auto thread_s2r = tiled_s2r.get_thread_slice(thread_idx); + Tensor tSR_sAcc = thread_s2r.partition_S(sAcc); // ((Atom,AtomNum),ATOM_M,ATOM_N) + Tensor tSR_gC = thread_s2r.partition_D(gCt); // ((Atom,AtomNum),ATOM_M,ATOM_N,TILE_M,TILE_N) + Tensor tSR_gD = thread_s2r.partition_D(gDt); // ((Atom,AtomNum),ATOM_M,ATOM_N,TILE_M,TILE_N) + + // Allocate intermediate registers on the dst tensors + Tensor tSR_rAcc = make_tensor(take<0,3>(shape(tSR_gC))); // ((Atom,AtomNum),ATOM_M,ATOM_N) + Tensor tSR_rD = make_tensor(shape(tSR_rAcc)); // ((Atom,AtomNum),ATOM_M,ATOM_N) + + // Repeat the D-partitioning for coordinates and predication + Tensor cD = make_identity_tensor(make_shape(size<0>(gD),size<1>(gD))); // (BLK_M,BLK_N) -> (blk_m,blk_n) + Tensor cDt = flat_divide(cD, tile); // (SMEM_M,SMEM_N,TILE_M,TILE_N) + Tensor tSR_cD = thread_s2r.partition_D(cDt); // ((Atom,AtomNum),ATOM_M,ATOM_N,TILE_M,TILE_N) + + CUTE_STATIC_ASSERT(size<1>(tRS_rAcc) % size<3>(tSR_gC) == 0); // TILE_M divides MMA_M + CUTE_STATIC_ASSERT(size<2>(tRS_rAcc) % size<4>(tSR_gC) == 0); // TILE_N divides MMA_N + +#if 0 + if (thread_idx == 0 && m_coord == 0 && n_coord == 0) { + print("aC : "); print(accumulators.layout()); print("\n"); + print("gC : "); print(gC.layout()); print("\n"); + print("gD : "); print(gD.layout()); print("\n"); + print("sAcc : "); print(sAcc.layout()); print("\n"); + 
print("\n"); + print("tRS_sAcc : "); print(tRS_sAcc.layout()); print("\n"); + print("tRS_rAcc : "); print(tRS_rAcc.layout()); print("\n"); + print("\n"); + print("gDt : "); print(gDt.layout()); print("\n"); + print("tSR_sAcc : "); print(tSR_sAcc.layout()); print("\n"); + print("tSR_rAcc : "); print(tSR_rAcc.layout()); print("\n"); + print("\n"); + print("tSR_rD : "); print(tSR_rD.layout()); print("\n"); + print("tSR_gC : "); print(tSR_gC.layout()); print("\n"); + print("tSR_gD : "); print(tSR_gD.layout()); print("\n"); + print("\n"); + } +#endif + + // For each tiling needed for SmemLayout to cover shape(gD) + CUTLASS_PRAGMA_UNROLL + for (int step_m = 0; step_m < size<2>(cDt); ++step_m) { + CUTLASS_PRAGMA_UNROLL + for (int step_n = 0; step_n < size<3>(cDt); ++step_n) { + // Step 1. Copy to SMEM + CUTLASS_PRAGMA_UNROLL + for (int pipe_m = 0; pipe_m < size<1>(tRS_sAcc); ++pipe_m) { + CUTLASS_PRAGMA_UNROLL + for (int pipe_n = 0; pipe_n < size<2>(tRS_sAcc); ++pipe_n) { + int mma_m = step_m * size<1>(tRS_sAcc) + pipe_m; + int mma_n = step_n * size<2>(tRS_sAcc) + pipe_n; + + copy(tiled_r2s, tRS_rAcc(_,mma_m,mma_n), tRS_sAcc(_,pipe_m,pipe_n)); + } + } + + // Step 2. Wait for SMEM writes to complete + synchronize(); + + // Step 3. Copy from SMEM into a fragment + copy(tiled_s2r, tSR_sAcc, tSR_rAcc); + + // Step 4. Wait for SMEM reads to complete + synchronize(); + + Tensor tSR_gDmn = tSR_gD(_,_,_,step_m,step_n); + Tensor tSR_cDmn = tSR_cD(_,_,_,step_m,step_n); + + if (epilogue_op.is_source_needed()) { + // source is needed + Tensor tSR_gCmn = tSR_gC(_,_,_,step_m,step_n); + + Tensor tSR_rCmn = make_tensor(shape(tSR_gCmn)); // ((Atom,AtomNum),ATOM_M,ATOM_N) + + // Step 5. Copy C from GMEM to a fragment + CUTLASS_PRAGMA_UNROLL + for (int m = 0; m < size<1>(tSR_gDmn); ++m) { + CUTLASS_PRAGMA_UNROLL + for (int n = 0; n < size<2>(tSR_gDmn); ++n) { + // Predication + if (elem_less(tSR_cDmn(0,m,n), take<0,2>(residue_mnk))) { + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size<0>(tSR_rAcc); ++i) { + tSR_rCmn(i,m,n) = tSR_gCmn(i,m,n); + } + } + } + } + + CUTLASS_PRAGMA_UNROLL + for (int m = 0; m < size<1>(tSR_gDmn); ++m) { + CUTLASS_PRAGMA_UNROLL + for (int n = 0; n < size<2>(tSR_gDmn); ++n) { + // Predication + if (elem_less(tSR_cDmn(0,m,n), take<0,2>(residue_mnk))) { + // Step 6. Elementwise operation with conversion + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size<0>(tSR_rAcc); ++i) { + tSR_rD(i,m,n) = epilogue_op(tSR_rAcc(i,m,n), tSR_rCmn(i,m,n)); + } + // Step 7. Copy to GMEM + copy(CopyAtomR2G{}, tSR_rD(_,m,n), tSR_gDmn(_,m,n)); + } + } + } + } + else { + // source is not needed, avoid load and lift compute + + // Step 5. Elementwise operation with conversion + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(tSR_rAcc); ++i) { + tSR_rD(i) = epilogue_op(tSR_rAcc(i)); + } + + CUTLASS_PRAGMA_UNROLL + for (int m = 0; m < size<1>(tSR_gDmn); ++m) { + CUTLASS_PRAGMA_UNROLL + for (int n = 0; n < size<2>(tSR_gDmn); ++n) { + // Predication + if (elem_less(tSR_cDmn(0,m,n), take<0,2>(residue_mnk))) { + // Step 6. 
Copy to GMEM + copy(CopyAtomR2G{}, tSR_rD(_,m,n), tSR_gDmn(_,m,n)); + } + } + } + } + } + } + } + +private: + Params params; +}; + + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace collective +} // namespace epilogue +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/epilogue/collective/sm90_epilogue_array_tma_warpspecialized.hpp b/include/cutlass/epilogue/collective/sm90_epilogue_array_tma_warpspecialized.hpp index 87b6786721..56bdd84344 100644 --- a/include/cutlass/epilogue/collective/sm90_epilogue_array_tma_warpspecialized.hpp +++ b/include/cutlass/epilogue/collective/sm90_epilogue_array_tma_warpspecialized.hpp @@ -77,7 +77,8 @@ template < class CopyOpS2G_, class SmemLayoutAtomD_, class CopyOpR2S_, - class CopyAtomC_ + class CopyAtomC_, + class CopyOpR2R_ > class CollectiveEpilogue< Sm90PtrArrayTmaWarpSpecialized { public: // @@ -129,7 +131,7 @@ class CollectiveEpilogue< using SmemLayoutAtomD = SmemLayoutAtomD_; using CopyOpR2S = CopyOpR2S_; using CopyAtomC = CopyAtomC_; - + using CopyOpR2R = CopyOpR2R_; using ThreadEpilogueOp = typename epilogue::fusion::FusionCallbacksTraits::Operation; using GmemTiledCopyC = CopyOpG2S; @@ -164,6 +166,9 @@ class CollectiveEpilogue< constexpr static bool is_im2col_C = cute::is_same_v; constexpr static bool is_im2col_D = cute::is_same_v; + // Check if register transformation is needed before copying register to shared memory. + constexpr static bool IsUseR2R = !cute::is_void_v; + using SmemLayoutC = decltype(tile_to_shape( SmemLayoutAtomC{}, make_shape(size<0>(EpilogueTile{}), size<1>(EpilogueTile{}), Int{}), @@ -233,7 +238,7 @@ class CollectiveEpilogue< FusionStorage thread; } tensors; - struct TensorMapStorage : cute::aligned_struct<128> { + struct TensorMapStorage : cute::aligned_struct<128, _0> { cute::TmaDescriptor smem_tensormap_C; cute::array smem_tensormap_D; } tensormaps; @@ -265,7 +270,7 @@ class CollectiveEpilogue< take<0,2>(SmemLayoutC{}), EpilogueTile{}, _1{})); - + using TMA_D = decltype(make_tma_copy( CopyOpS2G{}, make_tensor(make_gmem_ptr(static_cast(nullptr)), @@ -333,7 +338,6 @@ class CollectiveEpilogue< take<0,2>(SmemLayoutC{}), EpilogueTile{}, _1{}); - } typename Params::TMA_D tma_store_d; @@ -369,16 +373,18 @@ class CollectiveEpilogue< template static size_t get_workspace_size(ProblemShape const& problem_shape, Arguments const& args, int sm_count) { + constexpr uint32_t NumInputTensors = NumEpilogueWarpGroups + (cute::is_void_v ? 
0 : 1); auto descriptors_shape = cute::make_shape(sm_count, Int{}); constexpr size_t SizeOfCuTensorMap = sizeof(cute::TmaDescriptor); + // Allocate gmem space for input tensormaps per each SM, A tensormap copies followed by B tensormap copies return (size(descriptors_shape) * SizeOfCuTensorMap) + FusionCallbacks::get_workspace_size(problem_shape, args.thread); } template static cutlass::Status - initialize_workspace(ProblemShape const& problem_shape, Arguments const& args, void* workspace, cudaStream_t stream, + initialize_workspace(ProblemShape const& problem_shape, Arguments const& args, void* workspace, cudaStream_t stream, CudaHostAdapter* cuda_adapter = nullptr) { return FusionCallbacks::initialize_workspace(problem_shape, args.thread, workspace, stream, cuda_adapter); } @@ -408,10 +414,10 @@ class CollectiveEpilogue< constexpr int min_tma_aligned_elements_C = tma_alignment_bits_C / cutlass::sizeof_bits::value; implementable = implementable && cutlass::detail::check_alignment(cute::make_shape(M,N,L), InternalStrideC{}); } - + fusion_implementable = fusion_implementable && FusionCallbacks::can_implement(problem_shape_MNKL, args.thread); } - } + } else { CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Ignoring check to can implement because host problem shape is not available.\n"); } @@ -507,9 +513,9 @@ class CollectiveEpilogue< auto [m_coord, n_coord, k_coord, l_coord] = tile_coord_mnkl; static_assert(!is_im2col_D, "Do not support im2col"); - + auto coord_shape = append<3>(make_shape(m_coord, n_coord), Int<0>{}); - + // Represent the full source tensor, slice to get the tile this CTA is currently responsible for Tensor mC_mn = params.tma_load_c.get_tma_tensor(append<3>(make_shape(M,N), Int<1>{})); // (M,N,L) Tensor mC = coalesce(mC_mn, take<0,2>(CtaTileMNK{})); @@ -542,12 +548,8 @@ class CollectiveEpilogue< // Predication for TMA load (one thread issues TMA load) bool issue_tma_load = cute::elect_one_sync(); - // Acquire the lock for the first stage - load_pipeline.producer_acquire(load_pipe_producer_state); - uint64_t* tma_barrier = load_pipeline.producer_get_barrier(load_pipe_producer_state); - // Pre-loop fusion callback entry point - pld_callbacks.begin(tma_barrier, load_pipe_producer_state.count(), issue_tma_load); + pld_callbacks.begin(); LoadPipelineState prior_state = load_pipe_producer_state; @@ -560,9 +562,11 @@ class CollectiveEpilogue< if (subtile_idx != -1 && (epi_n * static_cast(size<2>(gC_epi)) + epi_m) != subtile_idx) { continue; } + // Acquire the lock for this stage constexpr uint16_t mcast_mask = 0; uint64_t* tma_barrier = load_pipeline.producer_get_barrier(load_pipe_producer_state); + load_pipeline.producer_acquire(load_pipe_producer_state); // Loop fusion callback entry point @@ -589,7 +593,7 @@ class CollectiveEpilogue< pld_callbacks.end(); if (wait_until_load_finishes && did_load) { - typename CollectiveEpilogue::LoadPipelineState epi_load_pipe_tma_consumer_state = + typename CollectiveEpilogue::LoadPipelineState epi_load_pipe_tma_consumer_state = {last_load_producer_state.index(), !last_load_producer_state.phase(), last_load_producer_state.count()}; load_pipeline.consumer_wait(epi_load_pipe_tma_consumer_state); } @@ -661,6 +665,7 @@ class CollectiveEpilogue< // Represent the full output tensor, slice to get the tile this CTA is responsible for Tensor mD_mn = params.tma_store_d.get_tma_tensor(append<3>(make_shape(M,N), Int<1>{})); // (M,N,L) + Tensor mD = coalesce(mD_mn, take<0,2>(CtaTileMNK{})); Tensor gD = local_tile(mD, take<0,2>(CtaTileMNK{}), coord_shape); // 
(CTA_M,CTA_N) @@ -677,8 +682,27 @@ class CollectiveEpilogue< TiledCopy tiled_copy_C_atom = make_tiled_copy_C_atom(CopyAtomC{}, tiled_mma); + // (t)hread-partition for (r)egister to (r)egister copy (tRR_) + TiledCopy tiled_r2r = [&]() { + if constexpr (IsUseR2R) { + return make_tiled_copy_S(Copy_Atom{}, tiled_copy_C_atom); + } + else { + return make_tiled_copy_S(Copy_Atom, + ElementCompute>{}, tiled_copy_C_atom); + } + }(); + ThrCopy thread_r2r = tiled_r2r.get_slice(thread_idx); + // (t)hread-partition for (r)egister to (s)mem copy (tRS_) - TiledCopy tiled_r2s = make_tiled_copy_S(Copy_Atom{}, tiled_copy_C_atom); + TiledCopy tiled_r2s = [&]() { + if constexpr (IsUseR2R) { + return make_tiled_copy_D(Copy_Atom{}, tiled_r2r); + } + else { + return make_tiled_copy_S(Copy_Atom{}, tiled_copy_C_atom); + } + }(); ThrCopy thread_r2s = tiled_r2s.get_slice(thread_idx); Tensor tRS_rAcc = thread_r2s.retile_S(accumulators); // ((R2S,R2S_V),MMA_M,MMA_N) Tensor tRS_sD = thread_r2s.partition_D(sD_epi); // (R2S,R2S_M,R2S_N,PIPE_D) @@ -733,6 +757,8 @@ class CollectiveEpilogue< CUTE_STATIC_ASSERT(epi_tile_m % mma_tile_m == 0, "MMA_TILE_M must divide EPI_TILE_M"); CUTE_STATIC_ASSERT(mma_tile_n % epi_tile_n == 0, "EPI_TILE_N must divide MMA_TILE_N"); + // Get TiledCopy for partition reference when consumer store. + TiledCopy tiled_copy_partition_ref = make_tiled_copy_S(Copy_Atom{}, tiled_copy_C_atom); // Get the fusion callbacks for the consumer store warps constexpr bool RefSrc = true; // Register tensors reference R2S copy src layout auto cst_args = cutlass::epilogue::fusion::detail::ConsumerStoreArgs{ @@ -741,7 +767,7 @@ class CollectiveEpilogue< tile_coord_mnkl, tiled_mma, EpilogueTile{}, - tiled_r2s, + tiled_copy_partition_ref, cD, residue_cD, tRS_cD, @@ -774,7 +800,7 @@ class CollectiveEpilogue< // Sync requirements of smem reuse may preclude this optimization // Delayed stores cause delayed stage releases which causes deadlock when StagesC == StagesD int epi_m_prev = 0, epi_n_prev = 0; - static_assert(not (DelayTmaStore and ReuseSmemC and StagesC == StagesD), "This TMA epilogue configuration will deadlock"); + static_assert(not (DelayTmaStore and ReuseSmemC and StagesC <= StagesD), "This TMA epilogue configuration will deadlock"); // The TMA store sequence for one subtile iteration auto tma_store_fn = [&] (int epi_m, int epi_n) { @@ -886,6 +912,16 @@ class CollectiveEpilogue< cst_callbacks.reduce(sD_epi(_,_,store_pipe_producer_state.index()), synchronize, epi_m, epi_n, is_last_iteration, tRS_rD_frg); + // Copy tile from register to regiser if needed + if constexpr (IsUseR2R) { + // retile source and destination for tiled_r2r + Tensor tRR_rD_src = thread_r2r.retile_S(tRS_rD); // (R2R,R2R_M,R2R_N,EPI_M,EPI_N) + Tensor tRR_rD_dst = thread_r2r.retile_D(tRS_rD); // (R2R,R2R_M,R2R_N,EPI_M,EPI_N) + + // Output needs register shuffling before copying to shared memory. 
+ copy(tiled_r2r, tRR_rD_src, tRR_rD_dst); + } + // Copy tile from register to smem if constexpr (is_destination_supported) { copy(tiled_r2s, tRS_rD, tRS_sD(_,_,_,store_pipe_producer_state.index())); @@ -905,6 +941,7 @@ class CollectiveEpilogue< } // for epi_m } // for epi_n + if constexpr (DelayTmaStore) { // Issue TMA stores for the last subtile tma_store_fn(epi_m_prev, epi_n_prev); @@ -991,6 +1028,7 @@ class CollectiveEpilogue< } __syncwarp(); return cute::make_tuple(&gmem_tensormap(sm_idx, C_tensormap_index)); + } TmaDescriptor* null_tma_desc = nullptr; return cute::make_tuple(null_tma_desc); @@ -1065,7 +1103,7 @@ class CollectiveEpilogue< prob_shape, prob_stride); } - } + } else if constexpr (is_destination_supported) { ElementD const* ptr_D = nullptr; @@ -1096,8 +1134,8 @@ class CollectiveEpilogue< ProblemShape_MNKL problem_shape_mnkl, int32_t next_batch, int32_t warp_group_idx) { - if (cute::elect_one_sync()) { + if (cute::elect_one_sync()) { // Replacing global_address for the next batch tensormaps_replace_global_address(shared_tensormaps, params, next_batch, warp_group_idx); @@ -1106,6 +1144,7 @@ class CollectiveEpilogue< tensormaps_replace_global_tensor_properties( shared_tensormaps, params, next_batch, problem_shape_mnkl, warp_group_idx); } + } } @@ -1117,6 +1156,7 @@ class CollectiveEpilogue< cute::TmaDescriptor const* tensormap, [[maybe_unused]] uint32_t lane_predicate, int32_t warp_group_idx = 0) { + // Entire warp must do this (ie its aligned) if constexpr (IsLoad) { if constexpr (is_source_supported) { @@ -1136,7 +1176,7 @@ class CollectiveEpilogue< if constexpr (not cute::is_void_v) { cute::tma_descriptor_fence_acquire(tensormap); } - } + } else { cute::tma_descriptor_fence_acquire(tensormap); } diff --git a/include/cutlass/epilogue/collective/sm90_epilogue_tma_warpspecialized.hpp b/include/cutlass/epilogue/collective/sm90_epilogue_tma_warpspecialized.hpp index 6d173bfe23..b2fa4e3573 100644 --- a/include/cutlass/epilogue/collective/sm90_epilogue_tma_warpspecialized.hpp +++ b/include/cutlass/epilogue/collective/sm90_epilogue_tma_warpspecialized.hpp @@ -75,7 +75,8 @@ template < class CopyOpS2G_, class SmemLayoutAtomD_, class CopyOpR2S_, - class CopyAtomC_ + class CopyAtomC_, + class CopyOpR2R_ > class CollectiveEpilogue< Sm90TmaWarpSpecialized, @@ -92,7 +93,8 @@ class CollectiveEpilogue< CopyOpS2G_, SmemLayoutAtomD_, CopyOpR2S_, - CopyAtomC_ + CopyAtomC_, + CopyOpR2R_, > { public: // @@ -113,6 +115,7 @@ class CollectiveEpilogue< using SmemLayoutAtomD = SmemLayoutAtomD_; using CopyOpR2S = CopyOpR2S_; using CopyAtomC = CopyAtomC_; + using CopyOpR2R = CopyOpR2R_; using ThreadEpilogueOp = typename epilogue::fusion::FusionCallbacksTraits::Operation; using GmemTiledCopyC = CopyOpG2S; @@ -147,6 +150,9 @@ class CollectiveEpilogue< constexpr static bool is_im2col_C = cute::is_same_v; constexpr static bool is_im2col_D = cute::is_same_v; + // Check if register transformation is needed before copying register to shared memory. 
+ constexpr static bool IsUseR2R = !cute::is_void_v; + using SmemLayoutC = decltype(tile_to_shape( SmemLayoutAtomC{}, make_shape(size<0>(EpilogueTile{}), size<1>(EpilogueTile{}), Int{}), @@ -454,12 +460,8 @@ class CollectiveEpilogue< // Predication for TMA load (one thread issues TMA load) bool issue_tma_load = cute::elect_one_sync(); - // Acquire the lock for the first stage - uint64_t* tma_barrier = load_pipeline.producer_get_barrier(load_pipe_producer_state); - load_pipeline.producer_acquire(load_pipe_producer_state); - // Pre-loop fusion callback entry point - pld_callbacks.begin(tma_barrier, load_pipe_producer_state.count(), issue_tma_load); + pld_callbacks.begin(); CUTLASS_PRAGMA_UNROLL for (int epi_n = 0; epi_n < size<3>(gC_epi); ++epi_n) { @@ -568,8 +570,27 @@ class CollectiveEpilogue< TiledCopy tiled_copy_C_atom = make_tiled_copy_C_atom(CopyAtomC{}, tiled_mma); + // (t)hread-partition for (r)egister to (r)egister copy (tRR_) + TiledCopy tiled_r2r = [&]() { + if constexpr (IsUseR2R) { + return make_tiled_copy_S(Copy_Atom{}, tiled_copy_C_atom); + } + else { + return make_tiled_copy_S(Copy_Atom, + ElementCompute>{}, tiled_copy_C_atom); + } + }(); + ThrCopy thread_r2r = tiled_r2r.get_slice(thread_idx); + // (t)hread-partition for (r)egister to (s)mem copy (tRS_) - TiledCopy tiled_r2s = make_tiled_copy_S(Copy_Atom{}, tiled_copy_C_atom); + TiledCopy tiled_r2s = [&]() { + if constexpr (IsUseR2R) { + return make_tiled_copy_D(Copy_Atom{}, tiled_r2r); + } + else { + return make_tiled_copy_S(Copy_Atom{}, tiled_copy_C_atom); + } + }(); ThrCopy thread_r2s = tiled_r2s.get_slice(thread_idx); Tensor tRS_rAcc = thread_r2s.retile_S(accumulators); // ((R2S,R2S_V),MMA_M,MMA_N) Tensor tRS_sD = thread_r2s.partition_D(sD_epi); // (R2S,R2S_M,R2S_N,PIPE_D) @@ -581,7 +602,7 @@ class CollectiveEpilogue< // Allocate D registers Layout tRS_rD_layout = make_layout(take<0,3>(shape(thread_r2s.partition_S(sD_epi)))); - Tensor tRS_rD = make_tensor(tRS_rD_layout); // (R2S,R2S_M,R2S_N) + Tensor tRS_rD = make_tensor(tRS_rD_layout); // (R2S,R2S_M,R2S_N) // Vectorized fragment view constexpr int FragmentSize = DispatchPolicy::FragmentSize; @@ -624,15 +645,17 @@ class CollectiveEpilogue< CUTE_STATIC_ASSERT(epi_tile_m % mma_tile_m == 0, "MMA_TILE_M must divide EPI_TILE_M"); CUTE_STATIC_ASSERT(mma_tile_n % epi_tile_n == 0, "EPI_TILE_N must divide MMA_TILE_N"); + // Get TiledCopy for partition reference when consumer store. 
+ TiledCopy tiled_copy_partition_ref = make_tiled_copy_S(Copy_Atom{}, tiled_copy_C_atom); // Get the fusion callbacks for the consumer store warps - constexpr bool RefSrc = true; // Register tensors reference R2S copy src layout + constexpr bool RefSrc = true; // Register tensors reference tiled copy src layout auto cst_args = cutlass::epilogue::fusion::detail::ConsumerStoreArgs( problem_shape_mnkl, CtaTileMNK{}, tile_coord_mnkl, tiled_mma, EpilogueTile{}, - tiled_r2s, + tiled_copy_partition_ref, cD, residue_cD, tRS_cD, @@ -647,7 +670,7 @@ class CollectiveEpilogue< using FragmentVisit = decltype(cst_callbacks.visit(tRS_rAcc_frg(0), 0, 0, 0)); constexpr bool IsDirectR2S = cute::is_same_v>; using RegisterElementD = cute::conditional_t; - Tensor tRS_rCompute = make_tensor(tRS_rD_layout); // (R2S,R2S_M,R2S_N) + Tensor tRS_rCompute = make_tensor(tRS_rD_layout); // (R2S,R2S_M,R2S_N) Tensor tRS_rCompute_frg = recast>(tRS_rCompute); // Thread synchronizer for previously issued waits or fences @@ -672,7 +695,7 @@ class CollectiveEpilogue< // Delayed stores cause delayed stage releases which causes deadlock when StagesC == StagesD [[maybe_unused]] int epi_m_prev = 0; [[maybe_unused]] int epi_n_prev = 0; - static_assert(not (DelayTmaStore and ReuseSmemC and StagesC == StagesD), "This TMA epilogue configuration will deadlock"); + static_assert(not (DelayTmaStore and ReuseSmemC and StagesC <= StagesD), "This TMA epilogue configuration will deadlock"); // The TMA store sequence for one subtile iteration auto tma_store_fn = [&] (int epi_m, int epi_n) { @@ -784,6 +807,16 @@ class CollectiveEpilogue< cst_callbacks.reduce(sD_epi(_,_,store_pipe_producer_state.index()), synchronize, epi_m, epi_n, is_last_iteration, tRS_rCompute_frg); + // Copy tile from register to regiser if needed + if constexpr (IsUseR2R) { + // retile source and destination for tiled_r2r + Tensor tRR_rD_src = thread_r2r.retile_S(tRS_rCompute); // (R2R,R2R_M,R2R_N,EPI_M,EPI_N) + Tensor tRR_rD_dst = thread_r2r.retile_D(tRS_rCompute); // (R2R,R2R_M,R2R_N,EPI_M,EPI_N) + + // Output register transformation before copying to shared memory. 
+ copy(tiled_r2r, tRR_rD_src, tRR_rD_dst); + } + CUTLASS_PRAGMA_UNROLL for (int i = 0; i < size(tRS_rD_frg); ++i) { tRS_rD_frg(i) = cutlass::NumericArrayConverter{}(tRS_rCompute_frg(i)); diff --git a/include/cutlass/epilogue/collective/sm90_epilogue_tma_warpspecialized_bias_elementwise.hpp b/include/cutlass/epilogue/collective/sm90_epilogue_tma_warpspecialized_bias_elementwise.hpp index b67c229c27..9749040081 100644 --- a/include/cutlass/epilogue/collective/sm90_epilogue_tma_warpspecialized_bias_elementwise.hpp +++ b/include/cutlass/epilogue/collective/sm90_epilogue_tma_warpspecialized_bias_elementwise.hpp @@ -62,7 +62,8 @@ template < class CopyOpS2G_, class SmemLayoutAtomD_, class CopyOpR2S_, - class CopyAtomC_ + class CopyAtomC_, + class CopyOpR2R_ > class Sm90EpilogueTmaWarpSpecializedBiasElementwise : public CollectiveEpilogue< @@ -80,7 +81,8 @@ class Sm90EpilogueTmaWarpSpecializedBiasElementwise CopyOpS2G_, SmemLayoutAtomD_, CopyOpR2S_, - CopyAtomC_ + CopyAtomC_, + CopyOpR2R_ > { private: using Impl = @@ -99,7 +101,8 @@ class Sm90EpilogueTmaWarpSpecializedBiasElementwise CopyOpS2G_, SmemLayoutAtomD_, CopyOpR2S_, - CopyAtomC_ + CopyAtomC_, + CopyOpR2R_ >; public: using DispatchPolicy = Sm90TmaWarpSpecializedBiasElementwise; diff --git a/include/cutlass/epilogue/dispatch_policy.hpp b/include/cutlass/epilogue/dispatch_policy.hpp index e96f413445..f829a2ff5d 100644 --- a/include/cutlass/epilogue/dispatch_policy.hpp +++ b/include/cutlass/epilogue/dispatch_policy.hpp @@ -46,12 +46,13 @@ namespace cutlass::epilogue { ////////////////////////////////////////////////////////////////////////////// struct PtrArrayDefault {}; +struct EpilogueSimtVectorized {}; +struct EpiloguePtrArraySimtVectorized {}; struct NoSmemWarpSpecialized {}; struct PtrArrayNoSmemWarpSpecialized {}; struct PtrArrayPlanarComplexNoSmemWarpSpecialized {}; struct TmaWarpSpecialized {}; struct TmaWarpSpecializedCooperative {}; - struct PtrArrayTmaWarpSpecializedCooperative { static constexpr int NumEpilogueWarpGroups = 2; }; diff --git a/include/cutlass/epilogue/fusion/operations.hpp b/include/cutlass/epilogue/fusion/operations.hpp index 0bfacf34cc..3aed32710f 100644 --- a/include/cutlass/epilogue/fusion/operations.hpp +++ b/include/cutlass/epilogue/fusion/operations.hpp @@ -33,6 +33,7 @@ #include #include +#include ///////////////////////////////////////////////////////////////////////////////////////////////// @@ -123,6 +124,19 @@ struct LinCombEltAct static constexpr bool IsEltActSupported = true; }; +// D = softmax(top_k(alpha * acc + beta * C)) +template< + int TopK, + class ElementOutput_, + class ElementCompute_, + class ElementSource_ = ElementOutput_, + class ElementScalar_ = ElementCompute_, + FloatRoundStyle RoundStyle_ = FloatRoundStyle::round_to_nearest +> +struct LinCombTopKSoftmaxCol + : LinearCombination { +}; + // D = alpha * acc + beta * C + per-row bias template< @@ -131,7 +145,7 @@ template< class ElementBias_ = ElementOutput_, class ElementSource_ = ElementOutput_, class ElementScalar_ = ElementCompute_, - int AlignmentBias_ = 128 / sizeof_bits_v, + int AlignmentBias_ = 128 / cute::sizeof_bits_v, FloatRoundStyle RoundStyle_ = FloatRoundStyle::round_to_nearest > struct LinCombPerRowBias @@ -141,39 +155,39 @@ struct LinCombPerRowBias static constexpr bool IsPerRowBiasSupported = true; }; -// D = activation(alpha * acc + beta * C + per-row bias) +// D = alpha * acc + beta * C + per-column bias template< - template class ActivationFn_, class ElementOutput_, class ElementCompute_, class ElementBias_ = 
ElementOutput_, class ElementSource_ = ElementOutput_, class ElementScalar_ = ElementCompute_, - int AlignmentBias_ = 128 / sizeof_bits_v, + int AlignmentBias_ = 128 / cute::sizeof_bits_v, FloatRoundStyle RoundStyle_ = FloatRoundStyle::round_to_nearest > -struct LinCombPerRowBiasEltAct - : LinCombPerRowBias { - using ActivationFn = ActivationFn_; - static constexpr bool IsEltActSupported = true; +struct LinCombPerColBias + : LinearCombination { + using ElementBias = ElementBias_; + static constexpr int AlignmentBias = AlignmentBias_; + static constexpr bool IsPerColBiasSupported = true; }; -// D = alpha * acc + beta * C + per-column bias +// D = activation(alpha * acc + beta * C + per-row bias) template< + template class ActivationFn_, class ElementOutput_, class ElementCompute_, class ElementBias_ = ElementOutput_, class ElementSource_ = ElementOutput_, class ElementScalar_ = ElementCompute_, - int AlignmentBias_ = 128 / sizeof_bits_v, + int AlignmentBias_ = 128 / cute::sizeof_bits_v, FloatRoundStyle RoundStyle_ = FloatRoundStyle::round_to_nearest > -struct LinCombPerColBias - : LinearCombination { - using ElementBias = ElementBias_; - static constexpr int AlignmentBias = AlignmentBias_; - static constexpr bool IsPerColBiasSupported = true; +struct LinCombPerRowBiasEltAct + : LinCombPerRowBias { + using ActivationFn = ActivationFn_; + static constexpr bool IsEltActSupported = true; }; // D = activation(alpha * acc + beta * C + per-row bias) @@ -187,8 +201,8 @@ template< class ElementBias_ = ElementOutput_, class ElementSource_ = ElementOutput_, class ElementScalar_ = ElementCompute_, - int AlignmentAux_ = 128 / sizeof_bits_v, - int AlignmentBias_ = 128 / sizeof_bits_v, + int AlignmentAux_ = 128 / cute::sizeof_bits_v, + int AlignmentBias_ = 128 / cute::sizeof_bits_v, FloatRoundStyle RoundStyle_ = FloatRoundStyle::round_to_nearest > struct LinCombPerRowBiasEltActAux @@ -208,8 +222,8 @@ template< class ElementBias_ = ElementOutput_, class ElementSource_ = ElementOutput_, class ElementScalar_ = ElementCompute_, // per-row alpha/beta - int AlignmentBias_ = 128 / sizeof_bits_v, - int AlignmentScalar_ = 128 / sizeof_bits_v, + int AlignmentBias_ = 128 / cute::sizeof_bits_v, + int AlignmentScalar_ = 128 / cute::sizeof_bits_v, FloatRoundStyle RoundStyle_ = FloatRoundStyle::round_to_nearest > struct PerRowLinCombPerRowBiasEltAct @@ -231,7 +245,7 @@ template< class ElementBias_ = ElementOutput_, class ElementSource_ = ElementOutput_, class ElementScalar_ = ElementCompute_, - int AlignmentBias_ = 128 / sizeof_bits_v, + int AlignmentBias_ = 128 / cute::sizeof_bits_v, FloatRoundStyle RoundStyle_ = FloatRoundStyle::round_to_nearest > struct ScaledLinCombPerRowBiasEltAct @@ -261,8 +275,8 @@ template< class ElementBias_ = ElementOutput_, class ElementSource_ = ElementOutput_, class ElementScalar_ = ElementCompute_, - int AlignmentAux_ = 128 / sizeof_bits_v, - int AlignmentBias_ = 128 / sizeof_bits_v, + int AlignmentAux_ = 128 / cute::sizeof_bits_v, + int AlignmentBias_ = 128 / cute::sizeof_bits_v, FloatRoundStyle RoundStyle_ = FloatRoundStyle::round_to_nearest > struct ScaledLinCombPerRowBiasEltActAmaxAux @@ -288,7 +302,7 @@ template< class ElementAux_ = ElementOutput_, class ElementSource_ = ElementOutput_, class ElementScalar_ = ElementCompute_, - int AlignmentAux_ = 128 / sizeof_bits_v, + int AlignmentAux_ = 128 / cute::sizeof_bits_v, FloatRoundStyle RoundStyle_ = FloatRoundStyle::round_to_nearest > struct LinCombDeEltAct @@ -315,8 +329,8 @@ template< class ElementBias_ = ElementCompute_, class 
ElementSource_ = ElementOutput_, class ElementScalar_ = ElementCompute_, - int AlignmentAux_ = 128 / sizeof_bits_v, - int AlignmentBias_ = 128 / sizeof_bits_v, + int AlignmentAux_ = 128 / cute::sizeof_bits_v, + int AlignmentBias_ = 128 / cute::sizeof_bits_v, FloatRoundStyle RoundStyle_ = FloatRoundStyle::round_to_nearest > struct LinCombDeEltActDePerRowBias diff --git a/include/cutlass/epilogue/fusion/sm90_callbacks_tma_warpspecialized.hpp b/include/cutlass/epilogue/fusion/sm90_callbacks_tma_warpspecialized.hpp index ece5ac542e..e028846a4f 100644 --- a/include/cutlass/epilogue/fusion/sm90_callbacks_tma_warpspecialized.hpp +++ b/include/cutlass/epilogue/fusion/sm90_callbacks_tma_warpspecialized.hpp @@ -46,6 +46,8 @@ #include "cutlass/epilogue/fusion/sm90_visitor_store_tma_warpspecialized.hpp" #include "cutlass/epilogue/fusion/sm90_visitor_compute_tma_warpspecialized.hpp" +#include "cutlass/epilogue/fusion/sm90_visitor_topk_softmax.hpp" + ///////////////////////////////////////////////////////////////////////////////////////////////// namespace cutlass::epilogue::fusion { @@ -75,12 +77,12 @@ struct FusionCallbacks< CtaTileShapeMNK, EpilogueTile > : Sm90EVT, - Sm90ScalarBroadcast, + Sm90ScalarBroadcast>, Sm90AccFetch > { using Impl = Sm90EVT, - Sm90ScalarBroadcast, + Sm90ScalarBroadcast>, Sm90AccFetch >; using Operation = fusion::ScaledAcc; @@ -92,12 +94,15 @@ struct FusionCallbacks< ElementScalar const* alpha_ptr = nullptr; ElementScalar const* beta_ptr = nullptr; + using StrideAlpha = Stride<_0,_0,int64_t>; + StrideAlpha dAlpha = {_0{}, _0{}, 0}; + // Conversion to the args expected by the visitor implementation // to_underlying_arguments will implicitly call this operator typename Impl::Arguments() const { return { // binary op : alpha * acc - {{alpha}, {alpha_ptr}}, // leaf args : alpha + {{alpha}, {alpha_ptr}, {dAlpha}}, // leaf args : alpha {}, // leaf args : acc {} // binary args : multiplies }; // end binary op @@ -120,10 +125,10 @@ template< > using Sm90LinearCombination = Sm90EVT, // beta * C + (alpha * acc) - Sm90ScalarBroadcast, // beta + Sm90ScalarBroadcast>, // beta Sm90SrcFetch, // C Sm90EVT, // alpha * acc - Sm90ScalarBroadcast, // alpha + Sm90ScalarBroadcast>, // alpha Sm90AccFetch // acc > >; @@ -158,13 +163,18 @@ struct FusionCallbacks< ElementScalar const* alpha_ptr = nullptr; ElementScalar const* beta_ptr = nullptr; + using StrideAlpha = Stride<_0,_0,int64_t>; + using StrideBeta = Stride<_0,_0,int64_t>; + StrideAlpha dAlpha = {_0{}, _0{}, 0}; + StrideBeta dBeta = {_0{}, _0{}, 0}; + operator typename Impl::Arguments() const { return { // ternary op : beta * C + (alpha * acc) - {{beta}, {beta_ptr}}, // leaf args : beta + {{beta}, {beta_ptr}, {dBeta}}, // leaf args : beta {}, // leaf args : C { // binary op : alpha * acc - {{alpha}, {alpha_ptr}}, // leaf args : alpha + {{alpha}, {alpha_ptr}, {dAlpha}}, // leaf args : alpha {}, // leaf args : acc {} // binary args : multiplies }, // end binary op @@ -189,10 +199,10 @@ template< > using Sm90LinearCombinationPtrArray = Sm90EVT, // beta * C + (alpha * acc) - Sm90ScalarBroadcastPtrArray>, // beta + Sm90ScalarBroadcastPtrArray>, // beta Sm90SrcFetch, // C Sm90EVT, // alpha * acc - Sm90ScalarBroadcastPtrArray>, // alpha + Sm90ScalarBroadcastPtrArray>, // alpha Sm90AccFetch // acc > >; @@ -236,8 +246,8 @@ struct FusionCallbacks< ElementScalar const* const* alpha_ptr_array = nullptr; ElementScalar const* const* beta_ptr_array = nullptr; - using StrideAlpha = Stride<_0,_0,int>; - using StrideBeta = Stride<_0,_0,int>; + using 
StrideAlpha = Stride<_0,_0,int64_t>; + using StrideBeta = Stride<_0,_0,int64_t>; StrideAlpha dAlpha = {_0{}, _0{}, 0}; StrideBeta dBeta = {_0{}, _0{}, 0}; @@ -307,6 +317,11 @@ struct FusionCallbacks< ElementScalar const* alpha_ptr = nullptr; ElementScalar const* beta_ptr = nullptr; + using StrideAlpha = Stride<_0,_0,int64_t>; + using StrideBeta = Stride<_0,_0,int64_t>; + StrideAlpha dAlpha = {_0{}, _0{}, 0}; + StrideBeta dBeta = {_0{}, _0{}, 0}; + using ActivationArguments = typename Sm90Compute::Arguments; ActivationArguments activation = ActivationArguments(); @@ -314,10 +329,96 @@ struct FusionCallbacks< return { // unary op: activation(beta * C + (alpha * acc)) { // ternary op : beta * C + (alpha * acc) - {{beta}, {beta_ptr}}, // leaf args : beta + {{beta}, {beta_ptr}, {dBeta}}, // leaf args : beta {}, // leaf args : C { // binary op : alpha * acc - {{alpha}, {alpha_ptr}}, // leaf args : alpha + {{alpha}, {alpha_ptr}, {dAlpha}}, // leaf args : alpha + {}, // leaf args : acc + {} // binary args : multiplies + }, // end binary op + {} // ternary args : multiply_add + }, // end ternary op + activation // unary args: activation + }; // end unary op + } + }; + + // Ctor inheritance + using Impl::Impl; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +// D = activation(alpha * acc + beta * C), where beta and alpha can be vectors for each batch +template< + template class ActivationFn, + class ElementOutput, + class ElementCompute, + class ElementSource = ElementOutput, + class ElementScalar = ElementCompute, + FloatRoundStyle RoundStyle = FloatRoundStyle::round_to_nearest +> +using Sm90LinCombEltActPtrArray = + Sm90EVT, // activation(beta * C + (alpha * acc)) + Sm90LinearCombinationPtrArray // beta * C + (alpha * acc) + >; + +template < + int StagesC, + int StagesD, + int FragmentSize, + bool ReuseSmemC, + bool DelayTmaStore, + int NumEpilogueWarpGroups, + template class ActivationFn, + class ElementOutput, + class ElementCompute, + class ElementSource, + class ElementScalar, + FloatRoundStyle RoundStyle, + class CtaTileShapeMNK, + class EpilogueTile +> +struct FusionCallbacks< + epilogue::Sm90PtrArrayTmaWarpSpecialized, + fusion::LinCombEltAct, + CtaTileShapeMNK, + EpilogueTile +> : Sm90LinCombEltActPtrArray { + + using Impl = Sm90LinCombEltActPtrArray::type, ElementCompute, ElementSource, ElementScalar, RoundStyle>; + using Operation = fusion::LinCombEltAct; + + struct Arguments { + ElementScalar alpha = ElementScalar(1); + ElementScalar beta = ElementScalar(0); + ElementScalar const* alpha_ptr = nullptr; + ElementScalar const* beta_ptr = nullptr; + ElementScalar const* const* alpha_ptr_array = nullptr; + ElementScalar const* const* beta_ptr_array = nullptr; + + using StrideAlpha = Stride<_0,_0,int64_t>; + using StrideBeta = Stride<_0,_0,int64_t>; + StrideAlpha dAlpha = {_0{}, _0{}, 0}; + StrideBeta dBeta = {_0{}, _0{}, 0}; + + using ActivationArguments = typename Sm90Compute::Arguments; + ActivationArguments activation = ActivationArguments(); + + operator typename Impl::Arguments() const { + return + { // unary op: activation(beta * C + (alpha * acc)) + { // ternary op : beta * C + (alpha * acc) + {{beta}, {beta_ptr}, {beta_ptr_array}, {dBeta}}, // leaf args : beta + {}, // leaf args : C + { // binary op : alpha * acc + {{alpha}, {alpha_ptr}, {alpha_ptr_array}, {dAlpha}}, // leaf args : alpha {}, // leaf args : acc {} // binary args : multiplies }, // end binary op @@ -347,12 +448,12 @@ template< > using Sm90LinCombPerRowBias = 
Sm90EVT, // beta * C + (alpha * acc + bias) - Sm90ScalarBroadcast, // beta + Sm90ScalarBroadcast>, // beta Sm90SrcFetch, // C Sm90EVT, // alpha * acc + bias - Sm90ScalarBroadcast, // alpha + Sm90ScalarBroadcast>, // alpha Sm90AccFetch, // acc - Sm90ColBroadcast<0, CtaTileShapeMNK, ElementBias, Stride<_1,_0,int>, AlignmentBias> // bias + Sm90ColBroadcast<0, CtaTileShapeMNK, ElementBias, ElementCompute, Stride<_1,_0,int64_t>, AlignmentBias> // bias > >; @@ -390,17 +491,22 @@ struct FusionCallbacks< ElementScalar const* alpha_ptr = nullptr; ElementScalar const* beta_ptr = nullptr; - using StrideBias = Stride<_1,_0,int>; + using StrideAlpha = Stride<_0,_0,int64_t>; + using StrideBeta = Stride<_0,_0,int64_t>; + StrideAlpha dAlpha = {_0{}, _0{}, 0}; + StrideBeta dBeta = {_0{}, _0{}, 0}; + + using StrideBias = Stride<_1,_0,int64_t>; ElementBias const* bias_ptr = nullptr; StrideBias dBias = {}; operator typename Impl::Arguments() const { return { // ternary op : beta * C + (alpha * acc + bias) - {{beta}, {beta_ptr}}, // leaf args : beta + {{beta}, {beta_ptr}, {dBeta}}, // leaf args : beta {}, // leaf args : C { // ternary op : alpha * acc + bias - {{alpha}, {alpha_ptr}}, // leaf args : alpha + {{alpha}, {alpha_ptr}, {dAlpha}}, // leaf args : alpha {}, // leaf args : acc {bias_ptr, ElementBias(0), dBias}, // leaf args : bias {} // ternary args : multiply_add @@ -431,12 +537,12 @@ template< > using Sm90LinCombPerColBias = Sm90EVT, // beta * C + (alpha * acc + bias) - Sm90ScalarBroadcast, // beta + Sm90ScalarBroadcast>, // beta Sm90SrcFetch, // C Sm90EVT, // alpha * acc + bias - Sm90ScalarBroadcast, // alpha + Sm90ScalarBroadcast>, // alpha Sm90AccFetch, // acc - Sm90RowBroadcast<0, CtaTileShapeMNK, ElementBias, Stride<_0,_1,int>, AlignmentBias> // bias + Sm90RowBroadcast<0, CtaTileShapeMNK, ElementBias, ElementCompute, Stride<_0,_1,int64_t>, AlignmentBias> // bias > >; @@ -474,17 +580,22 @@ struct FusionCallbacks< ElementScalar const* alpha_ptr = nullptr; ElementScalar const* beta_ptr = nullptr; - using StrideBias = Stride<_0,_1,int>; + using StrideAlpha = Stride<_0,_0,int64_t>; + using StrideBeta = Stride<_0,_0,int64_t>; + StrideAlpha dAlpha = {_0{}, _0{}, 0}; + StrideBeta dBeta = {_0{}, _0{}, 0}; + + using StrideBias = Stride<_0,_1,int64_t>; ElementBias const* bias_ptr = nullptr; StrideBias dBias = {}; operator typename Impl::Arguments() const { return { // ternary op : beta * C + (alpha * acc + bias) - {{beta}, {beta_ptr}}, // leaf args : beta + {{beta}, {beta_ptr}, {dBeta}}, // leaf args : beta {}, // leaf args : C { // ternary op : alpha * acc + bias - {{alpha}, {alpha_ptr}}, // leaf args : alpha + {{alpha}, {alpha_ptr}, {dAlpha}}, // leaf args : alpha {}, // leaf args : acc {bias_ptr, ElementBias(0), dBias}, // leaf args : bias {} // ternary args : multiply_add @@ -560,7 +671,12 @@ struct FusionCallbacks< ElementScalar const* alpha_ptr = nullptr; ElementScalar const* beta_ptr = nullptr; - using StrideBias = Stride<_1,_0,int>; + using StrideAlpha = Stride<_0,_0,int64_t>; + using StrideBeta = Stride<_0,_0,int64_t>; + StrideAlpha dAlpha = {_0{}, _0{}, 0}; + StrideBeta dBeta = {_0{}, _0{}, 0}; + + using StrideBias = Stride<_1,_0,int64_t>; ElementBias const* bias_ptr = nullptr; StrideBias dBias = {}; @@ -571,10 +687,10 @@ struct FusionCallbacks< return { // unary op : activation(beta * C + (alpha * acc + bias)) { // ternary op : beta * C + (alpha * acc + bias) - {{beta}, {beta_ptr}}, // leaf args : beta + {{beta}, {beta_ptr}, {dBeta}}, // leaf args : beta {}, // leaf args : C { // ternary op : 
alpha * acc + bias - {{alpha}, {alpha_ptr}}, // leaf args : alpha + {{alpha}, {alpha_ptr}, {dAlpha}}, // leaf args : alpha {}, // leaf args : acc {bias_ptr, ElementBias(0), dBias}, // leaf args : bias {} // ternary args : multiply_add @@ -673,7 +789,12 @@ struct FusionCallbacks< ElementScalar const* alpha_ptr = nullptr; ElementScalar const* beta_ptr = nullptr; - using StrideBias = Stride<_1,_0,int>; + using StrideAlpha = Stride<_0,_0,int64_t>; + using StrideBeta = Stride<_0,_0,int64_t>; + StrideAlpha dAlpha = {_0{}, _0{}, 0}; + StrideBeta dBeta = {_0{}, _0{}, 0}; + + using StrideBias = Stride<_1,_0,int64_t>; ElementBias const* bias_ptr = nullptr; StrideBias dBias = {}; @@ -689,10 +810,10 @@ struct FusionCallbacks< { // unary op : activation(store(beta * C + (alpha * acc + bias))) { // unary op : store(beta * C + (alpha * acc + bias)) { // ternary op : beta * C + (alpha * acc + bias) - {{beta}, {beta_ptr}}, // leaf args : beta + {{beta}, {beta_ptr}, {dBeta}}, // leaf args : beta {}, // leaf args : C { // ternary op : alpha * acc + bias - {{alpha}, {alpha_ptr}}, // leaf args : alpha + {{alpha}, {alpha_ptr}, {dAlpha}}, // leaf args : alpha {}, // leaf args : acc {bias_ptr, ElementBias(0), dBias}, // leaf args : bias {} // ternary args : multiply_add @@ -725,12 +846,12 @@ template< > using Sm90PerRowLinCombPerRowBias = Sm90EVT, // beta * C + (alpha * acc + bias) - Sm90ColBroadcast<0, CtaTileShapeMNK, ElementScalar, Stride<_1,_0,int>, AlignmentScalar>, // beta + Sm90ColBroadcast<0, CtaTileShapeMNK, ElementScalar, ElementCompute, Stride, AlignmentScalar>, // beta, dynamic scalar/vector broadcast Sm90SrcFetch, // C Sm90EVT, // alpha * acc + bias - Sm90ColBroadcast<0, CtaTileShapeMNK, ElementScalar, Stride<_1,_0,int>, AlignmentScalar>, // alpha + Sm90ColBroadcast<0, CtaTileShapeMNK, ElementScalar, ElementCompute, Stride, AlignmentScalar>, // alpha, dynamic scalar/vector broadcast Sm90AccFetch, // acc - Sm90ColBroadcast<0, CtaTileShapeMNK, ElementBias, Stride<_1,_0,int>, AlignmentBias> // bias + Sm90ColBroadcast<0, CtaTileShapeMNK, ElementBias, ElementCompute, Stride<_1,_0,int64_t>, AlignmentBias> // bias > >; @@ -792,16 +913,16 @@ struct FusionCallbacks< >; struct Arguments { - using StrideAlpha = Stride<_1,_0,int>; - using StrideBeta = Stride<_1,_0,int>; + using StrideAlpha = Stride; + using StrideBeta = Stride; ElementScalar alpha = ElementScalar(1); ElementScalar beta = ElementScalar(0); ElementScalar const* alpha_ptr = nullptr; ElementScalar const* beta_ptr = nullptr; - StrideAlpha dAlpha = {}; - StrideBeta dBeta = {}; + StrideAlpha dAlpha = {bool(1), _0{}, 0}; + StrideBeta dBeta = {bool(1), _0{}, 0}; - using StrideBias = Stride<_1,_0,int>; + using StrideBias = Stride<_1,_0,int64_t>; ElementBias const* bias_ptr = nullptr; StrideBias dBias = {}; @@ -864,12 +985,12 @@ template< > using Sm90ScaledLinCombPerRowBias = Sm90EVT, // beta * C + (alpha * acc + bias) - Sm90ScalarBroadcast, 2>, // scale_c * beta + Sm90ScalarBroadcast, 2>, // scale_c * beta Sm90SrcFetch, // C Sm90EVT, // alpha * acc + bias - Sm90ScalarBroadcast, 3>, // scale_a * scale_b * alpha + Sm90ScalarBroadcast, 3>, // scale_a * scale_b * alpha Sm90AccFetch, // acc - Sm90ColBroadcast<0, CtaTileShapeMNK, ElementBias, Stride<_1,_0,int>, AlignmentBias> // bias + Sm90ColBroadcast<0, CtaTileShapeMNK, ElementBias, ElementCompute, Stride<_1,_0,int64_t>, AlignmentBias> // bias > >; @@ -950,7 +1071,12 @@ struct FusionCallbacks< ElementScalar const* scale_c_ptr = nullptr; ElementScalar const* scale_d_ptr = nullptr; - using StrideBias = 
Stride<_1,_0,int>; + using StrideAlpha = Stride<_0,_0,int64_t>; + using StrideBeta = Stride<_0,_0,int64_t>; + StrideAlpha dAlpha = {_0{}, _0{}, 0}; + StrideBeta dBeta = {_0{}, _0{}, 0}; + + using StrideBias = Stride<_1,_0,int64_t>; ElementBias const* bias_ptr = nullptr; StrideBias dBias = {}; @@ -962,13 +1088,15 @@ struct FusionCallbacks< { // binary op : activation((scale_c * beta) * C + ((scale_a * scale_b * alpha) * acc + bias)) * scale_d { // unary op : activation((scale_c * beta) * C + ((scale_a * scale_b * alpha) * acc + bias)) { // ternary op : (scale_c * beta) * C + ((scale_a * scale_b * alpha) * acc + bias) - {{scale_c, beta}, - {scale_c_ptr, beta_ptr} + {{beta, scale_c}, + {beta_ptr, scale_c_ptr}, + {dBeta, {_0{}, _0{}, 0}} }, // leaf args : (scale_c * beta) {}, // leaf args : C { // ternary op : (scale_a * scale_b * alpha) * acc + bias - {{scale_a, scale_b, alpha}, - {scale_a_ptr, scale_b_ptr, alpha_ptr} + {{alpha, scale_a, scale_b}, + {alpha_ptr, scale_a_ptr, scale_b_ptr}, + {dAlpha, {_0{}, _0{}, 0}, {_0{}, _0{}, 0}} }, // leaf args : (scale_a * scale_b * alpha) {}, // leaf args : acc {bias_ptr, ElementBias(0), dBias}, // leaf args : bias @@ -1184,7 +1312,12 @@ struct FusionCallbacks< ElementScalar scale_aux = ElementScalar(1); ElementScalar const* scale_aux_ptr = nullptr; - using StrideBias = Stride<_1,_0,int>; + using StrideAlpha = Stride<_0,_0,int64_t>; + using StrideBeta = Stride<_0,_0,int64_t>; + StrideAlpha dAlpha = {_0{}, _0{}, 0}; + StrideBeta dBeta = {_0{}, _0{}, 0}; + + using StrideBias = Stride<_1,_0,int64_t>; ElementBias const* bias_ptr = nullptr; StrideBias dBias = {}; @@ -1213,13 +1346,15 @@ struct FusionCallbacks< Z_args = { // ternary op : (scale_c * beta) * C + ((scale_a * scale_b * alpha) * acc + bias) - {{scale_c, beta}, - {scale_c_ptr, beta_ptr} + {{beta, scale_c}, + {beta_ptr, scale_c_ptr}, + {dBeta, {_0{}, _0{}, 0}} }, // leaf args : (scale_c * beta) {}, // leaf args : C { // ternary op : (scale_a * scale_b * alpha) * acc + bias - {{scale_a, scale_b, alpha}, - {scale_a_ptr, scale_b_ptr, alpha_ptr} + {{alpha, scale_a, scale_b}, + {alpha_ptr, scale_a_ptr, scale_b_ptr}, + {dAlpha ,{_0{}, _0{}, 0}, {_0{}, _0{}, 0}} }, // leaf args : (scale_a * scale_b * alpha) {}, // leaf args : acc {bias_ptr, ElementBias(0), dBias}, // leaf args : bias @@ -1269,13 +1404,15 @@ struct FusionCallbacks< { // unary op : activation(Z) { // unary op : store(Z) { // ternary op : (scale_c * beta) * C + ((scale_a * scale_b * alpha) * acc + bias) - {{scale_c, beta}, - {scale_c_ptr, beta_ptr} + {{beta, scale_c}, + {beta_ptr, scale_c_ptr}, + {dBeta, {_0{}, _0{}, 0}} }, // leaf args : (scale_c * beta) {}, // leaf args : C { // ternary op : (scale_a * scale_b * alpha) * acc + bias - {{scale_a, scale_b, alpha}, - {scale_a_ptr, scale_b_ptr, alpha_ptr} + {{alpha, scale_a, scale_b}, + {alpha_ptr, scale_a_ptr, scale_b_ptr}, + {dAlpha, {_0{}, _0{}, 0}} }, // leaf args : (scale_a * scale_b * alpha) {}, // leaf args : acc {bias_ptr, ElementBias(0), dBias @@ -1377,6 +1514,11 @@ struct FusionCallbacks< ElementScalar const* alpha_ptr = nullptr; ElementScalar const* beta_ptr = nullptr; + using StrideAlpha = Stride<_0,_0,int64_t>; + using StrideBeta = Stride<_0,_0,int64_t>; + StrideAlpha dAlpha = {_0{}, _0{}, 0}; + StrideBeta dBeta = {_0{}, _0{}, 0}; + using ActivationArguments = typename Sm90Compute::Arguments; ActivationArguments activation = ActivationArguments(); @@ -1388,10 +1530,10 @@ struct FusionCallbacks< return { // binary op : activation(beta * C + (alpha * acc), aux) { // ternary op : beta * 
C + (alpha * acc) - {{beta}, {beta_ptr}}, // leaf args : beta + {{beta}, {beta_ptr}, {dBeta}}, // leaf args : beta {}, // leaf args : C { // binary op : alpha * acc - {{alpha}, {alpha_ptr}}, // leaf args : alpha + {{alpha}, {alpha_ptr}, {dAlpha}}, // leaf args : alpha {}, // leaf args : acc {} // binary args : multiplies }, // end binary op @@ -1430,7 +1572,7 @@ template< using Sm90LinCombDeEltActDePerRowBias = Sm90EVT, // Identity for final conversion Sm90EVT, AlignmentBias>, + ElementBias, ElementCompute, RoundStyle, Stride<_1,_0,int64_t>, AlignmentBias>, Sm90LinCombDeEltAct > @@ -1490,6 +1632,11 @@ struct FusionCallbacks< ElementScalar const* alpha_ptr = nullptr; ElementScalar const* beta_ptr = nullptr; + using StrideAlpha = Stride<_0,_0,int64_t>; + using StrideBeta = Stride<_0,_0,int64_t>; + StrideAlpha dAlpha = {_0{}, _0{}, 0}; + StrideBeta dBeta = {_0{}, _0{}, 0}; + using ActivationArguments = typename Sm90Compute::Arguments; ActivationArguments activation = ActivationArguments(); @@ -1497,7 +1644,7 @@ struct FusionCallbacks< ElementAux const* aux_ptr = nullptr; StrideAux dAux = {}; - using StrideBias = Stride<_1,_0,int>; + using StrideBias = Stride<_1,_0,int64_t>; ElementBias* dbias_ptr = nullptr; StrideBias dDbias = {}; @@ -1507,10 +1654,10 @@ struct FusionCallbacks< { // unary op : reduce(activation(beta * C + (alpha * acc), aux)) { // binary op : activation(beta * C + (alpha * acc), aux) { // ternary op : beta * C + (alpha * acc) - {{beta}, {beta_ptr}}, // leaf args : beta + {{beta}, {beta_ptr}, {dBeta}}, // leaf args : beta {}, // leaf args : C { // binary op : alpha * acc - {{alpha}, {alpha_ptr}}, // leaf args : alpha + {{alpha}, {alpha_ptr}, {dAlpha}}, // leaf args : alpha {}, // leaf args : acc {} // binary args : multiplies }, // end binary op @@ -1532,6 +1679,78 @@ struct FusionCallbacks< ///////////////////////////////////////////////////////////////////////////////////////////////// +// D = softmax(top_k(alpha * acc + beta * C)) +template< + int TopK, + int FragmentSize, + class CtaTileShapeMNK, + class EpilogueTile, + class ElementOutput, + class ElementCompute, + class ElementSource = ElementOutput, + class ElementScalar = ElementCompute, + FloatRoundStyle RoundStyle = FloatRoundStyle::round_to_nearest +> +using Sm90LinCombTopKSoftmaxCol = + Sm90EVT, // softmax(top_k(beta * C + (alpha * acc))) + Sm90LinearCombination // beta * C + (alpha * acc) + >; + +template < + int TopK, + int StagesC, + int StagesD, + int FragmentSize, + bool ReuseSmemC, + bool DelayTmaStore, + class ElementOutput, + class ElementCompute, + class ElementSource, + class ElementScalar, + FloatRoundStyle RoundStyle, + class CtaTileShapeMNK, + class EpilogueTile +> +struct FusionCallbacks< + epilogue::Sm90TmaWarpSpecialized, + fusion::LinCombTopKSoftmaxCol, + CtaTileShapeMNK, + EpilogueTile +> : Sm90LinCombTopKSoftmaxCol { + + using Impl = Sm90LinCombTopKSoftmaxCol::type, ElementCompute, ElementSource, ElementScalar, RoundStyle>; + using Operation = fusion::LinCombTopKSoftmaxCol; + + struct Arguments { + ElementScalar alpha = ElementScalar(1); + ElementScalar beta = ElementScalar(0); + ElementScalar const* alpha_ptr = nullptr; + ElementScalar const* beta_ptr = nullptr; + + operator typename Impl::Arguments() const { + return + { // unary op: activation(beta * C + (alpha * acc)) + { // ternary op : beta * C + (alpha * acc) + {{beta}, {beta_ptr}}, // leaf args : beta + {}, // leaf args : C + { // binary op : alpha * acc + {{alpha}, {alpha_ptr}}, // leaf args : alpha + {}, // leaf args : acc + {} // 
binary args : multiplies + }, // end binary op + {} // ternary args : multiply_add + }, // end ternary op + {} // unary args: activation + }; // end unary op + } + }; + + // Ctor inheritance + using Impl::Impl; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + namespace detail { template > struct get_element_aux { diff --git a/include/cutlass/epilogue/fusion/sm90_visitor_compute_tma_warpspecialized.hpp b/include/cutlass/epilogue/fusion/sm90_visitor_compute_tma_warpspecialized.hpp index 2ae10a688a..131d0ba5b9 100644 --- a/include/cutlass/epilogue/fusion/sm90_visitor_compute_tma_warpspecialized.hpp +++ b/include/cutlass/epilogue/fusion/sm90_visitor_compute_tma_warpspecialized.hpp @@ -263,8 +263,16 @@ struct Sm90TreeVisitor< CUTLASS_DEVICE bool is_producer_load_needed() const { + auto const& scale_op = get<0>(Impl::ops); auto const& added_op = get<2>(Impl::ops); - return is_C_load_needed() || added_op.is_producer_load_needed(); + if constexpr (detail::IsScalarBroadcast::value && not is_void_v) { + return (get<2>(scale_op.params_ptr->dScalar[0]) != 0 && scale_op.params_ptr->scalar_ptrs[0] != nullptr) || + is_C_load_needed() || + added_op.is_producer_load_needed(); + } + else { + return is_C_load_needed() || added_op.is_producer_load_needed(); + } } CUTLASS_DEVICE bool @@ -296,7 +304,7 @@ struct Sm90TreeVisitor< Array frg_I = convert_Z(frg_added); - if (is_C_load_needed) { + if constexpr (!is_void_v) { Array frg_scalar = get<0>(CallbacksImpl::callbacks_tuple).visit(frg_acc, epi_v, epi_m, epi_n); Array frg_source = get<1>(CallbacksImpl::callbacks_tuple).visit(frg_acc, epi_v, epi_m, epi_n); @@ -323,8 +331,12 @@ struct Sm90TreeVisitor< CUTLASS_DEVICE auto get_consumer_store_callbacks(ConsumerStoreArgs const& args) { auto callbacks_tuple = Impl::template get_consumer_store_callbacks(args); + bool is_C_load_needed = this->is_C_load_needed(); + if (not is_C_load_needed) { + cute::clear(args.tCrC); + } return ConsumerStoreCallbacks( - is_C_load_needed(), std::move(callbacks_tuple)); + is_C_load_needed, std::move(callbacks_tuple)); } }; @@ -497,7 +509,18 @@ struct Sm90TreeVisitor< else { frg_compute[i] = relu(frg_compute[i]); } - frg_aux[i] = frg_compute[i] == pre_relu; + if constexpr (cute::is_same_v) { + uint32_t aux; + asm volatile("set.equ.u32.f32 %0, %1, %2;\n" : "=r"(aux) : "f"(frg_compute[i]), "f"(pre_relu)); // NaN outputs 1 in Aux + frg_aux[i] = static_cast(aux); + } else if constexpr (cute::is_same_v) { + uint32_t aux; + cutlass::half_t compute = frg_compute[i]; + asm volatile("set.equ.u32.f16 %0, %1, %2;\n" : "=r"(aux) : "h"(compute.raw()), "h"(pre_relu.raw())); // NaN outputs 1 in Aux + frg_aux[i] = static_cast(aux); + } else { + frg_aux[i] = frg_compute[i] == pre_relu; + } } static_assert(FragmentSize % 8 == 0, "Predicate vector must be byte-aligned"); diff --git a/include/cutlass/epilogue/fusion/sm90_visitor_load_tma_warpspecialized.hpp b/include/cutlass/epilogue/fusion/sm90_visitor_load_tma_warpspecialized.hpp index aedacb552e..a22bed4e0d 100644 --- a/include/cutlass/epilogue/fusion/sm90_visitor_load_tma_warpspecialized.hpp +++ b/include/cutlass/epilogue/fusion/sm90_visitor_load_tma_warpspecialized.hpp @@ -378,6 +378,174 @@ struct Sm90AuxLoad { } }; +template < + class Element, + class EpilogueTile, // Unused + class LayoutOrStrideMNL, + class SmemLayoutAtom, // Unused + class CopyOpS2R, // Unused + int Alignment, + bool EnableNullptr +> +struct Sm90AuxLoad< + 0, EpilogueTile, Element, LayoutOrStrideMNL, + SmemLayoutAtom, 
CopyOpS2R, Alignment, EnableNullptr +> { + using ElementAux = Element; + using StrideMNL = cutlass::gemm::TagToStrideC_t; + + struct SharedStorage { }; + + struct Arguments { + Element const* ptr_aux = nullptr; + Element null_default = Element(0); + StrideMNL dAux = {}; + }; + + using Params = Arguments; + + template + static constexpr Params + to_underlying_arguments(ProblemShape const& problem_shape, Arguments const& args, void* workspace) { + return args; + } + + template + static bool + can_implement(ProblemShape const& problem_shape, Arguments const& args) { + return true; + } + + template + static size_t + get_workspace_size(ProblemShape const& problem_shape, Arguments const& args) { + return 0; + } + + template + static cutlass::Status + initialize_workspace(ProblemShape const& problem_shape, Arguments const& args, void* workspace, cudaStream_t stream, + CudaHostAdapter* cuda_adapter = nullptr) { + return cutlass::Status::kSuccess; + } + + CUTLASS_HOST_DEVICE + Sm90AuxLoad() { } + + CUTLASS_HOST_DEVICE + Sm90AuxLoad(Params const& params, SharedStorage const& shared_storage) + : params_ptr(¶ms) { } + + Params const* params_ptr; + + CUTLASS_DEVICE bool + is_producer_load_needed() const { + return false; + } + + CUTLASS_DEVICE bool + is_C_load_needed() const { + return false; + } + + template + CUTLASS_DEVICE auto + get_producer_load_callbacks(ProducerLoadArgs const& args) { + return EmptyProducerLoadCallbacks{}; + } + + template< + class GTensorG2R, + class RTensor, + class CTensorG2R, + class ProblemShapeMNL + > + struct ConsumerStoreCallbacks : EmptyConsumerStoreCallbacks { + CUTLASS_DEVICE + ConsumerStoreCallbacks(GTensorG2R&& tC_gAux, + RTensor&& tC_rAux, + CTensorG2R&& tC_cAux, + ProblemShapeMNL problem_shape_mnl, + Params const* params_ptr) + : tC_gAux(cute::forward(tC_gAux)), + tC_rAux(cute::forward(tC_rAux)), + tC_cAux(cute::forward(tC_cAux)), + problem_shape_mnl(problem_shape_mnl), + params_ptr(params_ptr) {} + + GTensorG2R tC_gAux; + RTensor tC_rAux; + CTensorG2R tC_cAux; + ProblemShapeMNL problem_shape_mnl; + Params const* params_ptr; + + CUTLASS_DEVICE void + begin_loop(int epi_m, int epi_n) { + if constexpr (EnableNullptr) { + if (params_ptr->ptr_aux == nullptr) { + fill(tC_rAux, params_ptr->null_default); + return; + } + } + constexpr auto MCL = decltype(max_common_layout(tC_gAux(_,_,_,_0{},_0{}), tC_rAux)){}; + constexpr int V = cute::min(Alignment, size(MCL)); + + Tensor tC_cAux_mn = tC_cAux(_,_,_,epi_m,epi_n); + Tensor tC_cAux_vec = tensor<1>(zipped_divide(coalesce(tC_cAux_mn), MCL.compose(Int{}))); + + Tensor tC_gAux_vec = recast>(coalesce(tC_gAux(_,_,_,epi_m,epi_n))); + Tensor tC_rAux_vec = recast>(coalesce(tC_rAux)); + + auto pred_fn = [&] (auto const&... coords) { + return elem_less(tC_cAux_vec(coords...), problem_shape_mnl); + }; + + copy_if(pred_fn, tC_gAux_vec, tC_rAux_vec); + } + + template + CUTLASS_DEVICE Array + visit(Array const& frg_acc, int epi_v, int epi_m, int epi_n) { + return recast>(tC_rAux)(epi_v); + } + }; + + template < + bool ReferenceSrc, + class... 
Args + > + CUTLASS_DEVICE auto + get_consumer_store_callbacks(ConsumerStoreArgs const& args) { + auto [M, N, K, L] = args.problem_shape_mnkl; + auto [m, n, k, l] = args.tile_coord_mnkl; + + auto problem_shape_mnl = make_shape(M,N,L); + + // Gmem Tensor + Tensor mAux = make_tensor( + make_gmem_ptr(params_ptr->ptr_aux), make_shape(M,N,L), params_ptr->dAux + ); + Tensor tC_gAux = sm90_partition_for_epilogue( + mAux, args.tile_shape_mnk, args.tile_coord_mnkl, args.epi_tile, args.tiled_copy, args.thread_idx); + + // Register Tensor + Tensor tC_rAux = make_tensor(take<0,3>(shape(tC_gAux))); + + // Predication support + Tensor coordAux = make_identity_tensor(shape(mAux)); + Tensor tC_cAux = sm90_partition_for_epilogue( + coordAux, args.tile_shape_mnk, args.tile_coord_mnkl, args.epi_tile, args.tiled_copy, args.thread_idx); + + return ConsumerStoreCallbacks( + cute::move(tC_gAux), + cute::move(tC_rAux), + cute::move(tC_cAux), + problem_shape_mnl, + params_ptr + ); + } +}; + ///////////////////////////////////////////////////////////////////////////////////////////////// // // Broadcast Load Operations @@ -388,11 +556,12 @@ struct Sm90AuxLoad { // Supports reduction over multiple broadcasts to support fusions such as fp8 scaling factors template< class Element, - class StrideMNL = Stride<_0,_0,_0>, + class StrideMNL_ = Stride<_0,_0,_0>, int BroadcastCount = 1, template class ReductionFn = multiplies > struct Sm90ScalarBroadcast { + using StrideMNL = StrideMNL_; static_assert(is_static_v(StrideMNL{}))>); // batch stride can be dynamic or static static_assert(take<0,2>(StrideMNL{}) == Stride<_0,_0>{}); @@ -401,7 +570,7 @@ struct Sm90ScalarBroadcast { struct Arguments { Element scalars[BroadcastCount] = {}; Element const* scalar_ptrs[BroadcastCount] = {}; - StrideMNL dScalar = {}; + StrideMNL dScalar[BroadcastCount] = {}; }; using Params = Arguments; @@ -444,7 +613,21 @@ struct Sm90ScalarBroadcast { // This must be called after update_scalar is called CUTLASS_DEVICE bool is_zero() const { - return scalar == Element(0); + if (get<2>(params_ptr->dScalar[0]) == 0) { + // Only 1 batch + return scalar == Element(0); + } + else { + // multiple batch + if (valid_scalar == false) { + // for stridedBatch kernel, if ptr has a valid address, we need to enable the epi_load warps. + return params_ptr->scalar_ptrs[0] == nullptr; + } + else { + // Check whether each batch is ZERO or not. 
+ return scalar == Element(0); + } + } } CUTLASS_HOST_DEVICE @@ -454,19 +637,20 @@ struct Sm90ScalarBroadcast { Sm90ScalarBroadcast(Params const& params, SharedStorage const& shared_storage) : params_ptr(¶ms) { // Get the scalar for non-batched broadcast - if (get<2>(params_ptr->dScalar) == 0) { + if (size<2>(params_ptr->dScalar[0]) == 0) { update_scalar(); } } Element scalar; + bool valid_scalar = false; Params const* params_ptr; template CUTLASS_DEVICE auto get_producer_load_callbacks(ProducerLoadArgs const& args) { // Get the scalar for batched broadcast - if (get<2>(params_ptr->dScalar) != 0) { + if (size<2>(params_ptr->dScalar[0]) != 0) { auto [m_coord, n_coord, k_coord, l_coord] = args.tile_coord_mnkl; update_scalar(l_coord); } @@ -500,7 +684,7 @@ struct Sm90ScalarBroadcast { get_consumer_store_callbacks(ConsumerStoreArgs const& args) { // Get the scalar for batched broadcast - if (get<2>(params_ptr->dScalar) != 0) { + if (get<2>(params_ptr->dScalar[0]) != 0) { auto [m_coord, n_coord, k_coord, l_coord] = args.tile_coord_mnkl; update_scalar(l_coord); } @@ -511,11 +695,12 @@ struct Sm90ScalarBroadcast { private: CUTLASS_DEVICE void update_scalar(int l_coord = 0) { - int l_offset = l_coord * size<2>(params_ptr->dScalar); + valid_scalar = true; + int l_offset = l_coord * size<2>(params_ptr->dScalar[0]); if (params_ptr->scalar_ptrs[0] != nullptr) { scalar = params_ptr->scalar_ptrs[0][l_offset]; - } + } else { // batch stride is ignored for nullptr fallback scalar = params_ptr->scalars[0]; @@ -526,8 +711,10 @@ struct Sm90ScalarBroadcast { CUTLASS_PRAGMA_UNROLL for (int i = 1; i < BroadcastCount; ++i) { if (params_ptr->scalar_ptrs[i] != nullptr) { - scalar = reduction_fn(scalar, params_ptr->scalar_ptrs[i][l_offset]); - } else { + int rest_l_offset = l_coord * size<2>(params_ptr->dScalar[i]); + scalar = reduction_fn(scalar, params_ptr->scalar_ptrs[i][rest_l_offset]); + } + else { // batch stride is ignored for nullptr fallback scalar = reduction_fn(scalar, params_ptr->scalars[i]); } @@ -538,8 +725,8 @@ struct Sm90ScalarBroadcast { CUTLASS_DEVICE void update_scalar(cute::tuple) { // Only support multiple L-modes with fully-broadcast scalar - static_assert(cute::is_same_v>); scalar = params_ptr->scalars[0]; + valid_scalar = true; } }; @@ -706,6 +893,7 @@ struct Sm90ScalarBroadcastPtrArray { } }; + ///////////////////////////////////////////////////////////////////////////////////////////////// namespace detail { @@ -722,32 +910,40 @@ compute_row_broadcast_stages() { template< int Stages, class CtaTileShapeMNK, - class Element, - class StrideMNL = Stride<_0,_1,_0>, - int Alignment = 128 / sizeof_bits_v, + class ElementInput, + class ElementCompute = ElementInput, + class StrideMNL_ = Stride<_0,_1,_0>, + int Alignment = 128 / sizeof_bits_v, bool EnableNullptr = true // Fallback scalar broadcast for nullptr params > struct Sm90RowBroadcast { - static_assert(Stages == 0, "Row broadcast doesn't support smem usage"); - static_assert(is_static_v(StrideMNL{}))>); // batch stride can be dynamic or static - static_assert(take<0,2>(StrideMNL{}) == Stride<_0,_1>{}); + using StrideMNL = StrideMNL_; + static_assert(Stages == 0, "Row broadcast doesn't support smem pipelining"); + + static constexpr bool IsDynamicBroadcast = is_same_v(StrideMNL{}))>, bool>; // row vector or scalar broadcast + static_assert(is_static_v(StrideMNL{}))> || IsDynamicBroadcast); // batch stride can be dynamic or static + static_assert(take<0,2>(StrideMNL{}) == Stride<_0,_1>{} || IsDynamicBroadcast); struct SharedStorage { - 
array_aligned(CtaTileShapeMNK{})> smem; + array_aligned(CtaTileShapeMNK{})> smem; }; struct Arguments { - Element const* ptr_row = nullptr; - Element null_default = Element(0); + ElementInput const* ptr_row = nullptr; + ElementInput null_default = ElementInput(0); StrideMNL dRow = {}; }; - using Params = Arguments; + struct Params { + ElementInput const* ptr_row = nullptr; + ElementCompute null_default = ElementCompute(0); + StrideMNL dRow = {}; + }; template static constexpr Params to_underlying_arguments(ProblemShape const& problem_shape, Arguments const& args, void* workspace) { - return args; + return {args.ptr_row, ElementCompute(args.null_default), args.dRow}; } template @@ -774,11 +970,22 @@ struct Sm90RowBroadcast { CUTLASS_HOST_DEVICE Sm90RowBroadcast(Params const& params, SharedStorage const& shared_storage) - : params(params) - , smem(const_cast(shared_storage.smem.data())) { } + : params(params), is_zero_(false), + smem(const_cast(shared_storage.smem.data())) { + auto const& [stride_M, stride_N, stride_L] = params.dRow; + // Nullptr default + if (EnableNullptr && params.ptr_row == nullptr) { + is_zero_ = params.null_default == ElementCompute(0); + } + // Dynamic non-batched scalar broadcast + else if (IsDynamicBroadcast && stride_N == bool(0) && stride_L == repeat_like(stride_L, 0)) { + is_zero_ = params.ptr_row[0] == ElementInput(0); + } + } Params params; - Element *smem = nullptr; + bool is_zero_ = false; + ElementInput *smem = nullptr; CUTLASS_DEVICE bool is_producer_load_needed() const { @@ -792,7 +999,7 @@ struct Sm90RowBroadcast { CUTLASS_DEVICE bool is_zero() const { - return (params.ptr_row == nullptr && params.null_default == Element(0)); + return is_zero_; } template @@ -801,24 +1008,27 @@ struct Sm90RowBroadcast { return EmptyProducerLoadCallbacks{}; } - template + template struct ConsumerStoreCallbacks : EmptyConsumerStoreCallbacks { CUTLASS_DEVICE ConsumerStoreCallbacks( GS_GTensor tGS_gRow_, GS_STensor tGS_sRow_, GS_CTensor tGS_cRow_, Tiled_G2S tiled_g2s_, SR_STensor tSR_sRow_, SR_RTensor tSR_rRow_, - CTensor tCcRow_, ThrResidue residue_tCcRow_, ThrNum thr_num_, Params const& params_) + Residue residue_cRow_, ThrNum thr_num_, Params const& params_) : tGS_gRow(tGS_gRow_) , tGS_sRow(tGS_sRow_) , tGS_cRow(tGS_cRow_) , tiled_G2S(tiled_g2s_) , tSR_sRow(tSR_sRow_) , tSR_rRow(tSR_rRow_) - , tCcRow(tCcRow_) - , residue_tCcRow(residue_tCcRow_) + , residue_cRow(residue_cRow_) , params(params_) - , is_nullptr(EnableNullptr && params_.ptr_row == nullptr) {} + , is_nullptr(EnableNullptr && params_.ptr_row == nullptr) { + if (is_nullptr) { + fill(tSR_rRow, params.null_default); + } + } GS_GTensor tGS_gRow; // (CPY,CPY_M,CPY_N) GS_STensor tGS_sRow; // (CPY,CPY_M,CPY_N) @@ -828,35 +1038,31 @@ struct Sm90RowBroadcast { SR_STensor tSR_sRow; // (CPY,CPY_M,CPY_N,EPI_M,EPI_N) SR_RTensor tSR_rRow; // (CPY,CPY_M,CPY_N,EPI_M,EPI_N) - CTensor tCcRow; // (CPY,CPY_M,CPY_N,EPI_M,EPI_N) - ThrResidue residue_tCcRow; // (m, n) + Residue residue_cRow; // (m, n) ThrNum thr_num; Params const& params; bool is_nullptr; CUTLASS_DEVICE void begin() { - if constexpr (EnableNullptr) { - if (params.ptr_row == nullptr) { - fill(tSR_rRow, params.null_default); - return; - } + if (is_nullptr) { + return; } auto synchronize = [&] () { cutlass::arch::NamedBarrier::sync(thr_num, cutlass::arch::ReservedNamedBarriers::EpilogueBarrier); }; Tensor tGS_gRow_flt = filter_zeros(tGS_gRow); Tensor tGS_sRow_flt = filter_zeros(tGS_sRow); - Tensor tGS_cRow_flt = make_tensor(tGS_cRow.data(), 
make_layout(tGS_gRow_flt.shape(), tGS_cRow.stride())); + Tensor tGS_cRow_flt = filter_zeros(tGS_cRow, tGS_gRow.stride()); for (int i = 0; i < size(tGS_gRow_flt); ++i) { if (get<1>(tGS_cRow_flt(i)) >= size<1>(CtaTileShapeMNK{})) { continue; // OOB of SMEM, } - if (elem_less(tGS_cRow_flt(i), make_coord(get<0>(residue_tCcRow), get<1>(residue_tCcRow)))) { + if (elem_less(tGS_cRow_flt(i), residue_cRow)) { tGS_sRow_flt(i) = tGS_gRow_flt(i); } else { - tGS_sRow_flt(i) = Element(0); // Set to Zero when OOB so LDS could be issue without any preds. + tGS_sRow_flt(i) = ElementInput(0); // Set to Zero when OOB so LDS can be issued without any preds. } } synchronize(); @@ -864,18 +1070,28 @@ struct Sm90RowBroadcast { CUTLASS_DEVICE void begin_loop(int epi_m, int epi_n) { - if (epi_m == 0) { // Assumes M-major subtile loop - if (is_nullptr) return; // Do not issue LDS when bias is nullptr + if (epi_m == 0 and not is_nullptr) { // Assumes M-major subtile loop Tensor tSR_sRow_flt = filter_zeros(tSR_sRow(_,_,_,epi_m,epi_n)); - Tensor tSR_rRow_flt = filter_zeros(tSR_rRow); - copy(tSR_sRow_flt, tSR_rRow_flt); + Tensor tSR_rRow_flt = make_tensor_like(tSR_sRow_flt); + copy_aligned(tSR_sRow_flt, tSR_rRow_flt); + + constexpr int FrgSize = size(tSR_rRow_flt); + using FrgInput = Array; + using FrgCompute = Array; + using ConvertInput = NumericArrayConverter; + + Tensor tSR_rRow_input_frg = recast(coalesce(tSR_rRow_flt)); + Tensor tSR_rRow_compute_frg = recast(filter(tSR_rRow)); + ConvertInput convert_input{}; + + tSR_rRow_compute_frg(_0{}) = convert_input(tSR_rRow_input_frg(_0{})); } } template - CUTLASS_DEVICE Array + CUTLASS_DEVICE Array visit(Array const& frg_acc, int epi_v, int epi_m, int epi_n) { - Array frg_row; + Array frg_row; CUTLASS_PRAGMA_UNROLL for (int i = 0; i < FragmentSize; ++i) { @@ -896,12 +1112,30 @@ struct Sm90RowBroadcast { auto [m, n, k, l] = args.tile_coord_mnkl; using ThreadCount = decltype(size(args.tiled_copy)); - Tensor mRow = make_tensor(make_gmem_ptr(params.ptr_row), make_shape(M,N,L), params.dRow); + auto layout_N = [&] () { + auto shape_N = get<1>(args.problem_shape_mnkl); + if constexpr (IsDynamicBroadcast) { + auto stride_N = repeat_like(shape_N, int(0)); + if (get<1>(params.dRow) == bool(1)) { + stride_N = transform_leaf(compact_major(shape_N), + [] (auto const& stride) { return static_cast(stride); } + ); + } + return make_layout(shape_N, stride_N); + } + else { + return make_layout(shape_N); + } + }(); + + auto layout_M = make_layout(M, repeat_like(M, _0{})); + auto layout_L = make_layout(L, get<2>(params.dRow)); + Tensor mRow = make_tensor(make_gmem_ptr(params.ptr_row), make_layout(layout_M,layout_N,layout_L)); Tensor gRow = local_tile(mRow(_,_,l), take<0,2>(args.tile_shape_mnk), make_coord(m, n)); // (CTA_M, CTA_N) Tensor sRow = make_tensor(make_smem_ptr(smem), make_shape(size<0>(CtaTileShapeMNK{}), size<1>(CtaTileShapeMNK{})), make_shape(_0{}, _1{})); // (CTA_M, CTA_N) //// G2S: Gmem to Smem - auto tiled_g2s = make_tiled_copy(Copy_Atom{}, + auto tiled_g2s = make_tiled_copy(Copy_Atom{}, Layout< Shape<_1, ThreadCount>, Stride<_0, _1>>{}, Layout<_1>{}); @@ -910,20 +1144,18 @@ struct Sm90RowBroadcast { Tensor tGS_sRow = thr_g2s.partition_D(sRow); //// G2S: Coord - auto cRow = make_identity_tensor(make_shape(size<0>(CtaTileShapeMNK{}), size<1>(CtaTileShapeMNK{}))); - Tensor tGS_cRow = thr_g2s.partition_S(cRow); + Tensor tGS_cRow = thr_g2s.partition_S(args.cD); //// S2R: Smem to Reg Tensor tSR_sRow = sm90_partition_for_epilogue(sRow, args.epi_tile, args.tiled_copy, 
args.thread_idx); - Tensor tSR_rRow = make_tensor_like(take<0,3>(tSR_sRow)); // (CPY,CPY_M,CPY_N) + Tensor tSR_rRow = make_tensor_like(take<0,3>(tSR_sRow)); // (CPY,CPY_M,CPY_N) - return ConsumerStoreCallbacks( + return ConsumerStoreCallbacks( tGS_gRow, tGS_sRow, tGS_cRow, tiled_g2s, tSR_sRow, tSR_rRow, - args.tCcD, args.residue_cD, ThreadCount{}, params); @@ -936,31 +1168,39 @@ struct Sm90RowBroadcast { template< int Stages, class CtaTileShapeMNK, - class Element, - class StrideMNL = Stride<_1,_0,_0>, - int Alignment = 128 / sizeof_bits_v, + class ElementInput, + class ElementCompute = ElementInput, + class StrideMNL_ = Stride<_1,_0,_0>, + int Alignment = 128 / sizeof_bits_v, bool EnableNullptr = true // Fallback scalar broadcast for nullptr params > struct Sm90ColBroadcast { - static_assert(Stages == 0, "Column broadcast doesn't support smem usage"); - static_assert(is_static_v(StrideMNL{}))>); // batch stride can be dynamic or static - static_assert(take<0,2>(StrideMNL{}) == Stride<_1,_0>{}); + using StrideMNL = StrideMNL_; + static_assert(Stages == 0, "Column broadcast doesn't support smem pipelining"); + + static constexpr bool IsDynamicBroadcast = is_same_v(StrideMNL{}))>, bool>; // Column vector or scalar broadcast + static_assert(is_static_v(StrideMNL{}))> || IsDynamicBroadcast); // batch stride can be dynamic or static + static_assert(take<0,2>(StrideMNL{}) == Stride<_1,_0>{} || IsDynamicBroadcast); // Accumulator distributes col elements evenly amongst threads so we can just directly load from gmem struct SharedStorage { }; struct Arguments { - Element const* ptr_col = nullptr; - Element null_default = Element(0); + ElementInput const* ptr_col = nullptr; + ElementInput null_default = ElementInput(0); StrideMNL dCol = {}; }; - using Params = Arguments; + struct Params { + ElementInput const* ptr_col = nullptr; + ElementCompute null_default = ElementCompute(0); + StrideMNL dCol = {}; + }; template static constexpr Params to_underlying_arguments(ProblemShape const& problem_shape, Arguments const& args, void* workspace) { - return args; + return {args.ptr_col, ElementCompute(args.null_default), args.dCol}; } template @@ -994,7 +1234,7 @@ struct Sm90ColBroadcast { CUTLASS_DEVICE bool is_zero() const { - return (params.ptr_col == nullptr && params.null_default == Element(0)); + return is_zero_; } CUTLASS_HOST_DEVICE @@ -1002,9 +1242,20 @@ struct Sm90ColBroadcast { CUTLASS_HOST_DEVICE Sm90ColBroadcast(Params const& params, SharedStorage const& shared_storage) - : params(params) { } + : params(params), is_zero_(false) { + auto const& [stride_M, stride_N, stride_L] = params.dCol; + // Nullptr default + if (EnableNullptr && params.ptr_col == nullptr) { + is_zero_ = params.null_default == ElementCompute(0); + } + // Dynamic non-batched scalar broadcast + else if (IsDynamicBroadcast && stride_M == bool(0) && stride_L == repeat_like(stride_L, 0)) { + is_zero_ = params.ptr_col[0] == ElementInput(0); + } + } Params params; + bool is_zero_; template CUTLASS_DEVICE auto @@ -1015,12 +1266,16 @@ struct Sm90ColBroadcast { template struct ConsumerStoreCallbacks : EmptyConsumerStoreCallbacks { CUTLASS_DEVICE - ConsumerStoreCallbacks(GTensor&& tCgCol, RTensor&& tCrCol, CTensor tCcCol, ThrResidue residue_tCcCol, Params const& params) - : tCgCol(cute::forward(tCgCol)), - tCrCol(cute::forward(tCrCol)), - tCcCol(tCcCol), - residue_tCcCol(residue_tCcCol), - params(params) {} + ConsumerStoreCallbacks(GTensor tCgCol_, RTensor tCrCol_, CTensor tCcCol_, ThrResidue residue_tCcCol_, Params const& params_) + : 
tCgCol(tCgCol_), + tCrCol(tCrCol_), + tCcCol(tCcCol_), + residue_tCcCol(residue_tCcCol_), + params(params_) { + if (EnableNullptr && params.ptr_col == nullptr) { + fill(tCrCol, params.null_default); + } + } GTensor tCgCol; // (CPY,CPY_M,CPY_N,EPI_M,EPI_N) RTensor tCrCol; // (CPY,CPY_M,CPY_N,EPI_M,EPI_N) @@ -1030,23 +1285,20 @@ struct Sm90ColBroadcast { CUTLASS_DEVICE void begin() { - if constexpr (EnableNullptr) { - if (params.ptr_col == nullptr) { - fill(tCrCol, params.null_default); - return; - } + if (EnableNullptr && params.ptr_col == nullptr) { + return; } // Filter so we don't issue redundant copies over stride-0 modes // (only works if 0-strides are in same location, which is by construction) Tensor tCgCol_flt = filter_zeros(tCgCol); - Tensor tCrCol_flt = filter_zeros(tCrCol); - Tensor tCcCol_flt = make_tensor(tCcCol.data(), make_layout(tCrCol_flt.shape(), tCcCol.stride())); + Tensor tCrCol_flt = make_tensor_like(filter_zeros(tCrCol)); + Tensor tCcCol_flt = filter_zeros(tCcCol, tCgCol.stride()); constexpr auto MCL = decltype(max_common_layout(tCgCol_flt, tCrCol_flt)){}; constexpr int V = cute::min(Alignment, size(MCL)); if constexpr (V > 1) { - using VecType = uint_bit_t>; + using VecType = uint_bit_t>; Tensor tCgCol_vec = recast(coalesce(tCgCol_flt)); Tensor tCrCol_vec = recast(coalesce(tCrCol_flt)); Tensor tCcCol_vec = tensor<1>(zipped_divide(tCcCol_flt, MCL.compose(Int{}))); @@ -1057,12 +1309,23 @@ struct Sm90ColBroadcast { auto pred_fn = [&] (auto const&... coords) { return elem_less(tCcCol_flt(coords...), residue_tCcCol); }; copy_if(pred_fn, tCgCol_flt, tCrCol_flt); } + + constexpr int FrgSize = size(tCrCol_flt); + using FrgInput = Array; + using FrgCompute = Array; + using ConvertInput = NumericArrayConverter; + + Tensor tCrCol_input_frg = recast(coalesce(tCrCol_flt)); + Tensor tCrCol_compute_frg = recast(filter(tCrCol)); + ConvertInput convert_input{}; + + tCrCol_compute_frg(_0{}) = convert_input(tCrCol_input_frg(_0{})); } template - CUTLASS_DEVICE Array + CUTLASS_DEVICE Array visit(Array const& frg_acc, int epi_v, int epi_m, int epi_n) { - Array frg_col; + Array frg_col; Tensor tCrCol_mn = tCrCol(_,_,_,epi_m,epi_n); CUTLASS_PRAGMA_UNROLL @@ -1083,13 +1346,34 @@ struct Sm90ColBroadcast { get_consumer_store_callbacks(ConsumerStoreArgs const& args) { auto [M, N, K, L] = args.problem_shape_mnkl; - Tensor mCol = make_tensor(make_gmem_ptr(params.ptr_col), make_shape(M,N,L), params.dCol); + auto layout_M = [&] () { + auto shape_M = get<0>(args.problem_shape_mnkl); + if constexpr (IsDynamicBroadcast) { + auto stride_M = repeat_like(shape_M, int(0)); + if (get<0>(params.dCol) == bool(1)) { + stride_M = transform_leaf(compact_major(shape_M), + [] (auto const& stride) { return static_cast(stride); } + ); + } + return make_layout(shape_M, stride_M); + } + else { + return make_layout(shape_M); + } + }(); + + auto layout_N = make_layout(N, repeat_like(N, _0{})); + auto layout_L = make_layout(L, get<2>(params.dCol)); + Tensor mCol = make_tensor(make_gmem_ptr(params.ptr_col), make_layout(layout_M,layout_N,layout_L)); Tensor tCgCol = sm90_partition_for_epilogue( // (CPY,CPY_M,CPY_N,EPI_M,EPI_N) mCol, args.tile_shape_mnk, args.tile_coord_mnkl, args.epi_tile, args.tiled_copy, args.thread_idx); - Tensor tCrCol = make_tensor_like(tCgCol); // (CPY,CPY_M,CPY_N,EPI_M,EPI_N) - return ConsumerStoreCallbacks( - cute::move(tCgCol), cute::move(tCrCol), args.tCcD, args.residue_tCcD, params); + Tensor mCol_static = make_tensor(make_gmem_ptr(params.ptr_col), make_layout(make_layout(M),layout_N,layout_L)); 
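+    // mCol_static is a statically M-strided view of the same column vector; partitioning it below lets
+    // tCrCol keep a compile-time layout even when IsDynamicBroadcast makes dCol's strides runtime values.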
+ Tensor tCgCol_static = sm90_partition_for_epilogue( // (CPY,CPY_M,CPY_N,EPI_M,EPI_N) + mCol_static, args.tile_shape_mnk, args.tile_coord_mnkl, args.epi_tile, args.tiled_copy, args.thread_idx); + Tensor tCrCol = make_tensor_like(tCgCol_static); // (CPY,CPY_M,CPY_N,EPI_M,EPI_N) + + return ConsumerStoreCallbacks(tCgCol, tCrCol, args.tCcD, args.residue_tCcD, params); } }; @@ -1110,6 +1394,20 @@ template < using Sm90MatrixBroadcast = Sm90AuxLoad; +namespace detail { + +template +struct IsScalarBroadcast { + static constexpr bool value = false; +}; + +template +struct IsScalarBroadcast(typename Operation::StrideMNL{})), Stride<_0,_0>>>> { + static constexpr bool value = true; +}; + +} + ///////////////////////////////////////////////////////////////////////////////////////////////// } // namespace cutlass::epilogue::fusion diff --git a/include/cutlass/epilogue/fusion/sm90_visitor_store_tma_warpspecialized.hpp b/include/cutlass/epilogue/fusion/sm90_visitor_store_tma_warpspecialized.hpp index ae7b42b2bd..060f8d1594 100644 --- a/include/cutlass/epilogue/fusion/sm90_visitor_store_tma_warpspecialized.hpp +++ b/include/cutlass/epilogue/fusion/sm90_visitor_store_tma_warpspecialized.hpp @@ -286,6 +286,185 @@ struct Sm90AuxStore { } }; +template < + class Element, + class EpilogueTile, // Unused + FloatRoundStyle RoundStyle, + class LayoutOrStrideMNL, + class SmemLayoutAtom, // Unused + class CopyOpR2S, // Unused + int Alignment, + bool EnableNullptr +> +struct Sm90AuxStore< + 0, EpilogueTile, Element, RoundStyle, LayoutOrStrideMNL, + SmemLayoutAtom, CopyOpR2S, Alignment, EnableNullptr +> { + using ElementAux = Element; + using StrideMNL = cutlass::gemm::TagToStrideC_t; + + struct SharedStorage { }; + + struct Arguments { + Element* ptr_aux = nullptr; + StrideMNL dAux = {}; + }; + + using Params = Arguments; + + template + static constexpr Params + to_underlying_arguments(ProblemShape const& problem_shape, Arguments const& args, void* workspace) { + return args; + } + + template + static bool + can_implement(ProblemShape const& problem_shape, Arguments const& args) { + return true; + } + + template + static size_t + get_workspace_size(ProblemShape const& problem_shape, Arguments const& args) { + return 0; + } + + template + static cutlass::Status + initialize_workspace(ProblemShape const& problem_shape, Arguments const& args, void* workspace, cudaStream_t stream, + CudaHostAdapter* cuda_adapter = nullptr) { + return cutlass::Status::kSuccess; + } + + CUTLASS_HOST_DEVICE + Sm90AuxStore() { } + + CUTLASS_HOST_DEVICE + Sm90AuxStore(Params const& params, SharedStorage const& shared_storage) + : params_ptr(¶ms) { } + + Params const* params_ptr; + + CUTLASS_DEVICE bool + is_producer_load_needed() const { + return false; + } + + CUTLASS_DEVICE bool + is_C_load_needed() const { + return false; + } + + template + CUTLASS_DEVICE auto + get_producer_load_callbacks(ProducerLoadArgs const& args) { + return EmptyProducerLoadCallbacks{}; + } + + template< + class GTensorR2G, + class RTensor, + class CTensorR2G, + class ProblemShapeMNL + > + struct ConsumerStoreCallbacks : EmptyConsumerStoreCallbacks { + CUTLASS_DEVICE + ConsumerStoreCallbacks( + GTensorR2G&& tC_gAux, + RTensor&& tC_rAux, + CTensorR2G&& tC_cAux, + ProblemShapeMNL problem_shape_mnl, + Params const* params_ptr) + : tC_gAux(cute::forward(tC_gAux)), + tC_rAux(cute::forward(tC_rAux)), + tC_cAux(cute::forward(tC_cAux)), + problem_shape_mnl(problem_shape_mnl), + params_ptr(params_ptr) {} + + GTensorR2G tC_gAux; + RTensor tC_rAux; + CTensorR2G tC_cAux; + 
ProblemShapeMNL problem_shape_mnl; + Params const* params_ptr; + + template + CUTLASS_DEVICE auto + visit(Array const& frg_acc, int epi_v, int epi_m, int epi_n, + Array const& frg_input) { + using ConvertInput = NumericArrayConverter; + ConvertInput convert_input{}; + + Tensor tC_rAux_frg = recast>(coalesce(tC_rAux)); + tC_rAux_frg(epi_v) = convert_input(frg_input); + + return frg_input; + } + + CUTLASS_DEVICE void + end_loop(int epi_m, int epi_n) { + if constexpr (EnableNullptr) { + if (params_ptr->ptr_aux == nullptr) { + return; + } + } + + constexpr auto MCL = decltype(max_common_layout(tC_gAux(_,_,_,_0{},_0{}), tC_rAux)){}; + constexpr int V = cute::min(Alignment, size(MCL)); + + Tensor tC_cAux_mn = tC_cAux(_,_,_,epi_m,epi_n); + Tensor tC_cAux_vec = tensor<1>(zipped_divide(coalesce(tC_cAux_mn), MCL.compose(Int{}))); + + Tensor tC_gAux_vec = recast>(coalesce(tC_gAux(_,_,_,epi_m,epi_n))); + Tensor tC_rAux_vec = recast>(coalesce(tC_rAux)); + + auto pred_fn = [&] (auto const&... coords) { + return elem_less(tC_cAux_vec(coords...), problem_shape_mnl); + }; + + copy_if(pred_fn, tC_rAux_vec, tC_gAux_vec); + } + }; + + template < + bool ReferenceSrc, + class... Args + > + CUTLASS_DEVICE auto + get_consumer_store_callbacks(ConsumerStoreArgs const& args) { + + auto [M, N, K, L] = args.problem_shape_mnkl; + auto [m, n, k, l] = args.tile_coord_mnkl; + + auto problem_shape_mnl = make_shape(M,N,L); + + // Gmem Tensor + Tensor mAux = make_tensor( + make_gmem_ptr(params_ptr->ptr_aux), make_shape(M,N,L), params_ptr->dAux + ); + Tensor tC_gAux = sm90_partition_for_epilogue( + mAux, args.tile_shape_mnk, args.tile_coord_mnkl, args.epi_tile, args.tiled_copy, args.thread_idx); + + // Register Tensor + Tensor tC_rAux = make_tensor(take<0,3>(shape(tC_gAux))); + + // Predication support + Tensor coordAux = make_identity_tensor(shape(mAux)); + Tensor tC_cAux = sm90_partition_for_epilogue( + coordAux, args.tile_shape_mnk, args.tile_coord_mnkl, args.epi_tile, args.tiled_copy, args.thread_idx); + + return ConsumerStoreCallbacks( + cute::move(tC_gAux), + cute::move(tC_rAux), + cute::move(tC_cAux), + problem_shape_mnl, + params_ptr + ); + + } + +}; + ///////////////////////////////////////////////////////////////////////////////////////////////// // // Reduction Store Operations @@ -304,10 +483,8 @@ template < > struct Sm90ScalarReduction { private: - static_assert( - (cute::is_same_v>) || // scalar reduction, e.g. tensor max element - (cute::is_same_v>) || // batched scalar reduction, e.g. per-batch max element - (cute::is_same_v>)); + static_assert(is_static_v(StrideMNL{}))>); // batch stride can be dynamic or static + static_assert(take<0,2>(StrideMNL{}) == Stride<_0,_0>{}); static constexpr bool IsAtomic = is_atomic>::value; static_assert(IsAtomic, "non-atomic scalar reduction not supported yet"); @@ -344,6 +521,7 @@ struct Sm90ScalarReduction { static cutlass::Status initialize_workspace(ProblemShape const& problem_shape, Arguments const& args, void* workspace, cudaStream_t stream, CudaHostAdapter* cuda_adapter = nullptr) { + #if !defined(CUTLASS_SKIP_REDUCTION_INIT) if constexpr (IsAtomic) { auto [M, N, K, L] = problem_shape; Layout mScalar_layout = make_layout(make_shape(M,N,L), args.dScalar); @@ -351,6 +529,7 @@ struct Sm90ScalarReduction { return fill_workspace(args.ptr_scalar, ElementOutput(args.reduction_identity), cosize(mScalar_layout), stream, cuda_adapter); } } + #endif return cutlass::Status::kSuccess; } @@ -480,15 +659,18 @@ template < // tensor of ElementCompute. 
It is the user's responsibility to reduce this to a (N, L) tensor of ElementOutput bool FinalReduction = true, // False means skip OOB predication if OOB inputs are known to be the reduction identity - bool VisitCheckOOB = true + bool VisitCheckOOB = true, + // Indicate the parameter order when calling RegReduceFn + // Seq length equals the number of RegReduceFn parameters + // No.0 represents tCrRow; No.1 and subsequent numbers sequentially represent frg_inputs in `visit` + class RegReduceSeq = cute::seq<0, 1> > struct Sm90RowReduction { private: static_assert(Stages == 0, "Smem usage not supported yet"); static_assert(Alignment * sizeof_bits_v % 128 == 0, "sub-16B alignment not supported yet"); - static_assert( - (cute::is_same_v>) || // row vector reduction, e.g. per-col sum over all batches - (cute::is_same_v>)); // batched row vector reduction, e.g. per-col sum per batch + static_assert(is_static_v(StrideMNL{}))>); // batch stride can be dynamic or static + static_assert(take<0,2>(StrideMNL{}) == Stride<_0,_1>{}); static constexpr bool IsAtomic = is_atomic>::value; static_assert(not (IsAtomic && not FinalReduction), "atomic reduction must be final"); @@ -567,6 +749,7 @@ struct Sm90RowReduction { static cutlass::Status initialize_workspace(ProblemShape const& problem_shape, Arguments const& args, void* workspace, cudaStream_t stream, CudaHostAdapter* cuda_adapter = nullptr) { +#if !defined(CUTLASS_SKIP_REDUCTION_INIT) if constexpr (IsAtomic) { auto [M, N, K, L] = problem_shape; Layout mRow_layout = make_layout(make_shape(size<>(M),size<>(N),size<>(L)), args.dRow); @@ -575,7 +758,9 @@ struct Sm90RowReduction { } return Status::kSuccess; } - else if constexpr (FinalReduction) { + else +#endif + if constexpr (FinalReduction) { auto [M, N, K, L] = problem_shape; auto [tile_M, tile_N, tile_K] = CtaTileShapeMNK{}; size_t tile_counters_offset = product(ceil_div(make_shape(size<>(M),size<>(N),L), make_shape(tile_M, tile_N))) * tile_N * sizeof(ElementCompute); @@ -626,14 +811,13 @@ struct Sm90RowReduction { Params const& params; bool do_final_reduction = false; - - template + template CUTLASS_DEVICE auto visit(Array const& frg_acc, int epi_v, int epi_m, int epi_n, - Array const& frg_input) { + Array const&... frg_inputs) { if constexpr (EnableNullptr) { if (params.ptr_row == nullptr) { - return frg_input; + return cute::get<0>(cute::make_tuple(frg_inputs...)); } } @@ -643,21 +827,50 @@ struct Sm90RowReduction { Tensor tCrRow_mn = tCrRow(_,_,_,epi_m,epi_n); Tensor tCcRow_mn = tCcRow(_,_,_,epi_m,epi_n); - using ConvertInput = NumericArrayConverter; - using ReduceInput = RegReduceFn; - ConvertInput convert_input{}; - ReduceInput reduce_input{}; + if constexpr (VisitCheckOOB) { + using ReduceInput = RegReduceFn; + ReduceInput reduce_input{}; - Array frg_I = convert_input(frg_input); - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < FragmentSize; ++i) { - if (!VisitCheckOOB || elem_less(tCcRow_mn(epi_v * FragmentSize + i), residue_tCcRow)) { - ElementCompute& tCrRow_vmn = tCrRow_mn(epi_v * FragmentSize + i); - tCrRow_vmn = reduce_input(tCrRow_vmn, frg_I[i]); + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < FragmentSize; ++i) { + if (elem_less(tCcRow_mn(epi_v * FragmentSize + i), residue_tCcRow)) { + ElementCompute& tCrRow_vmn = tCrRow_mn(epi_v * FragmentSize + i); + tCrRow_vmn = transform_apply(cute::make_tuple(frg_inputs...), + [&] (auto&& frg_input) { + return ElementCompute(frg_input[i]); + }, + [&] (auto&&... 
cvt_frg_inputs) { + auto frg_compute_tuple = cute::make_tuple(tCrRow_vmn, cvt_frg_inputs...); + return cute::detail::apply(frg_compute_tuple, reduce_input, RegReduceSeq{}); + }); + } } } + else { + constexpr int RegFragSize = cute::max(1, static_cast(sizeof(uint32_t) / sizeof(ElementCompute))); + using ReduceInput = RegReduceFn>; + ReduceInput reduce_input{}; + Tensor tCrRow_mn_frg = recast>(tCrRow_mn); - return frg_input; + constexpr int RegFragArraySize = FragmentSize / RegFragSize; + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < RegFragArraySize; ++i) { + Array& tCrRow_vmn_frg = tCrRow_mn_frg(epi_v * RegFragArraySize + i); + tCrRow_vmn_frg = transform_apply(cute::make_tuple(frg_inputs...), + [&] (auto&& frg_input) { + using ElementInput = typename cute::remove_cvref_t::Element; + using ConvertInput = NumericArrayConverter; + using RegFragArr = Array, RegFragArraySize>; + ConvertInput convert_input{}; + return convert_input(reinterpret_cast(frg_input)[i]); + }, + [&] (auto&&... cvt_frg_inputs) { + auto frg_compute_tuple = cute::make_tuple(tCrRow_vmn_frg, cvt_frg_inputs...); + return cute::detail::apply(frg_compute_tuple, reduce_input, RegReduceSeq{}); + }); + } + } + return cute::get<0>(cute::make_tuple(frg_inputs...)); } template @@ -683,23 +896,70 @@ struct Sm90RowReduction { return; } + int lane_m = get<0>(lane_mn); + [[maybe_unused]] bool is_reduced_lane = lane_m == 0; + // // 1. Warp shuffle reduction // using FragmentShuffle = Array; + Tensor tCrRow_frg = recast(filter(tCrRow)); using ReduceShuffle = ShuffleReduceFn; ReduceShuffle reduce_shuffle{}; - Tensor tCrRow_frg = recast(filter(tCrRow)); - CUTLASS_PRAGMA_UNROLL - for (int reduction_rows = size<0>(lane_layout_MN) / 2; reduction_rows > 0; reduction_rows /= 2) { + + auto FrgSizePerLaneM = size(tCrRow_frg) / size<0>(lane_layout_MN); + constexpr bool SwapShuffle = FrgSizePerLaneM > 0; + + // + // Swap Shuffle + // + // The normal way to reduction among threads: + // use shuffle to let *** the first half of threads *** have *** whole data *** from the second half of threads. + // After each step of reduction, a half of threads won't work in the following steps. + // That is, as the reduction progresses, the efficiency of shuffle & reduction instructions gradually change from 1/2, 1/4 to 1/32 (the worst case). + // + // To overcome this shortcoming, for a NxN matrix to be reduced among N threads as a 1XN vectors, + // we use swap & shuffle aiming to let *** each half of threads *** have *** a half of data *** from the other half of threads. + // After reduction, each half of threads should deal with a (N/2)x(N/2) sub-matrix independently in the following step. + // We can recursively do this until the problem size is 1. 
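A minimal standalone sketch of the swap-and-shuffle idea described in the comment above (illustrative only, not part of this patch): 8 lanes of a warp each hold 8 column values; after log2(8) steps, lane m holds the full column-m sum in vals[0], and every lane does useful work at every step, unlike a __shfl_down_sync tree where half the lanes go idle at each level. The lane indexing, float data, sum reduction, and full-warp mask are simplifying assumptions, and the sketch assumes the whole warp calls the function.

__device__ void swap_shuffle_colsum(float (&vals)[8]) {
  int lane_m = threadIdx.x & 7;         // position within the 8-lane group
  for (int m = 4; m > 0; m /= 2) {      // active columns: 8 -> 4 -> 2 -> 1
    for (int r = 0; r < m; ++r) {
      float a = vals[r];
      float b = vals[r + m];
      // Step 1: swap, so the half we are giving away is the value we send
      if (!(lane_m & m)) { float t = a; a = b; b = t; }
      // Step 2: exchange with the partner lane at distance m
      float recv = __shfl_xor_sync(0xFFFFFFFF, a, m);
      // Step 3: reduce the half we keep with the partner's matching half
      vals[r] = b + recv;
    }
  }
}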
+ // + if constexpr (SwapShuffle) { // for a NxN matrix to be reduced among N threads as a 1XN vectors + Tensor tCrRow_frg_ = logical_divide(tCrRow_frg, FrgSizePerLaneM); // (FrgSizePerLaneM, M) + CUTLASS_PRAGMA_UNROLL + for (int m = size<1>(tCrRow_frg_) / 2; m > 0; m /= 2) { + CUTLASS_PRAGMA_UNROLL + for (int r = 0; r < m; ++r) { + auto frg_A = tCrRow_frg_(_,r); + auto frg_B = tCrRow_frg_(_,r + m); + CUTLASS_PRAGMA_UNROLL + for (int v = 0; v < size(frg_A); ++v) { + // Step1: swap + if (not (lane_m & m)) { // the first half of threads swap fragments from the first half of data to the second + swap(frg_A(v), frg_B(v)); + } + + // Step2: shuffle + uint64_t frg_shfl = reinterpret_cast(frg_A(v)); + // each half of threads get a half of data from the other half of threads + frg_shfl = __shfl_xor_sync(0xFFFFFFFF, frg_shfl, lane_layout_MN(m, _0{})); + + // Step3: reduction + frg_A(v) = reduce_shuffle(frg_B(v), reinterpret_cast(frg_shfl)); + } + } + } + } + else { CUTLASS_PRAGMA_UNROLL - for (int frg_idx = 0; frg_idx < size(tCrRow_frg); ++frg_idx) { - uint64_t frg_shfl = reinterpret_cast(tCrRow_frg(frg_idx)); - frg_shfl = __shfl_down_sync(0xFFFFFFFF, frg_shfl, lane_layout_MN(reduction_rows, _0{})); - tCrRow_frg(frg_idx) = reduce_shuffle(tCrRow_frg(frg_idx), reinterpret_cast(frg_shfl)); + for (int reduction_rows = size<0>(lane_layout_MN) / 2; reduction_rows > 0; reduction_rows /= 2) { + CUTLASS_PRAGMA_UNROLL + for (int frg_idx = 0; frg_idx < size(tCrRow_frg); ++frg_idx) { + uint64_t frg_shfl = reinterpret_cast(tCrRow_frg(frg_idx)); + frg_shfl = __shfl_down_sync(0xFFFFFFFF, frg_shfl, lane_layout_MN(reduction_rows, _0{})); + tCrRow_frg(frg_idx) = reduce_shuffle(tCrRow_frg(frg_idx), reinterpret_cast(frg_shfl)); + } } } - bool is_reduced_lane = get<0>(lane_mn) == 0; // // 2. 
Atomic reduction @@ -708,6 +968,7 @@ struct Sm90RowReduction { // Filter so we don't issue redunant copies over stride-0 modes Tensor tCrRow_flt = filter_zeros(tCrRow); Tensor tCcRow_flt = make_tensor(tCcRow.data(), make_layout(tCrRow_flt.shape(), tCcRow.stride())); + auto FltFrgSizePerLaneM = size(tCrRow_flt) / size<0>(lane_layout_MN); Tensor tCgRow = sm90_partition_for_epilogue(gRow_l(_,_,l), epi_tile, tiled_copy, thread_idx); Tensor tCgRow_flt = filter_zeros(tCgRow); @@ -717,11 +978,23 @@ struct Sm90RowReduction { ConvertOutput convert_output{}; ReduceOutput reduce_output{}; - if (is_reduced_lane) { + if constexpr (SwapShuffle) { CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < size(tCrRow_flt); ++i) { - if (elem_less(tCcRow_flt(i), residue_tCcRow)) { - reduce_output(&tCgRow_flt(i), convert_output(tCrRow_flt(i))); + for (int i = 0; i < FltFrgSizePerLaneM; ++i) { + int idx = lane_m * FltFrgSizePerLaneM + i; + // Only care about OOB for N mode + if (get<1>(tCcRow_flt(idx)) < get<1>(residue_tCcRow)) { + reduce_output(&tCgRow_flt(idx), convert_output(tCrRow_flt(i))); + } + } + } + else { + if (is_reduced_lane) { + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(tCrRow_flt); ++i) { + if (elem_less(tCcRow_flt(i), residue_tCcRow)) { + reduce_output(&tCgRow_flt(i), convert_output(tCrRow_flt(i))); + } } } } @@ -735,10 +1008,21 @@ struct Sm90RowReduction { // Dump warp reduction to gmem workspace using ElementGmem = cute::conditional_t; Tensor tCgBuf = sm90_partition_for_epilogue(gBuf_ml(_,_,m,l), epi_tile, tiled_copy, thread_idx); - if (is_reduced_lane) { - // Filter so we don't issue redundant copies over stride-0 modes - // (only works if 0-strides are in same location, which is by construction) - copy_aligned(filter(tCrRow), recast(filter(tCgBuf))); + + if constexpr (SwapShuffle) { + Tensor tCrRow_flt = filter(tCrRow); + Tensor tCgBuf_flt = recast(filter(tCgBuf)); + auto FltFrgSizePerLaneM = size(tCrRow_flt) / size<0>(lane_layout_MN); + Tensor tCgBuf_flt_ = logical_divide(tCgBuf_flt, FltFrgSizePerLaneM); // (FltFrgSizePerLaneM, M) + Tensor tCrRow_flt_ = logical_divide(tCrRow_flt, FltFrgSizePerLaneM); // (FltFrgSizePerLaneM, M) + copy_aligned(tCrRow_flt_(_,_0{}), tCgBuf_flt_(_,lane_m)); + } + else { + if (is_reduced_lane) { + // Filter so we don't issue redundant copies over stride-0 modes + // (only works if 0-strides are in same location, which is by construction) + copy_aligned(filter(tCrRow), recast(filter(tCgBuf))); + } } sync_fn(); } @@ -755,10 +1039,21 @@ struct Sm90RowReduction { // Dump warp reduction to smem workspace Tensor tCsBuf = sm90_partition_for_epilogue(sBuf(_,_,get<0>(warp_mn)), epi_tile, tiled_copy, thread_idx); - if (is_reduced_lane) { - // Filter so we don't issue redunant copies over stride-0 modes - // (only works if 0-strides are in same location, which is by construction) - copy_aligned(filter(tCrRow), filter(tCsBuf)); + + if constexpr (SwapShuffle) { + Tensor tCrRow_flt = filter(tCrRow); + Tensor tCsBuf_flt = filter(tCsBuf); + auto FltFrgSizePerLaneM = size(tCrRow_flt) / size<0>(lane_layout_MN); + Tensor tCsBuf_flt_ = logical_divide(tCsBuf_flt, FltFrgSizePerLaneM); // (FltFrgSizePerLaneM, M) + Tensor tCrRow_flt_ = logical_divide(tCrRow_flt, FltFrgSizePerLaneM); // (FltFrgSizePerLaneM, M) + copy_aligned(tCrRow_flt_(_,_0{}), tCsBuf_flt_(_,lane_m)); + } + else { + if (is_reduced_lane) { + // Filter so we don't issue redunant copies over stride-0 modes + // (only works if 0-strides are in same location, which is by construction) + copy_aligned(filter(tCrRow), 
filter(tCsBuf)); + } } sync_fn(); @@ -772,25 +1067,30 @@ struct Sm90RowReduction { Tensor sBuf_vec = recast(filter_zeros(sBuf)); constexpr int FragsPerRow = decltype(size<1>(sBuf_frg))::value; - // Do the threadblock smem reduction - CUTLASS_PRAGMA_UNROLL - for (int reduction_rows = size<0>(warp_layout_MN) / 2; reduction_rows > 1; reduction_rows /= 2) { - int FragsPerReduction = reduction_rows * FragsPerRow; - CUTLASS_PRAGMA_NO_UNROLL - for (int frg_idx = thread_idx; frg_idx < FragsPerReduction; frg_idx += size(tiled_copy)) { - FragmentSmem frg_smem = reduce_smem(sBuf_frg(frg_idx), sBuf_frg(frg_idx + FragsPerReduction)); - sBuf_vec(frg_idx) = reinterpret_cast(frg_smem); - } - sync_fn(); - } + constexpr int RowNum = decltype(size<0>(warp_layout_MN))::value; + using FragmentSmemArray = Array; - // Do final smem reduction and dump to gmem workspace + // Do the threadblock smem reduction using VectorGmem = cute::conditional_t; Tensor gBuf_vec = recast(filter(gBuf_ml(_,_,m,l))); - CUTLASS_PRAGMA_NO_UNROLL + CUTLASS_PRAGMA_UNROLL for (int frg_idx = thread_idx; frg_idx < FragsPerRow; frg_idx += size(tiled_copy)) { - FragmentSmem frg_smem = reduce_smem(sBuf_frg(frg_idx), sBuf_frg(frg_idx + FragsPerRow)); - gBuf_vec(frg_idx) = reinterpret_cast(frg_smem); + FragmentSmemArray frg_smem; + + CUTLASS_PRAGMA_UNROLL + for (int reduction_rows = 0; reduction_rows < RowNum; ++reduction_rows) { + int FragsCurrRows = reduction_rows * FragsPerRow; + frg_smem[reduction_rows] = sBuf_frg(FragsCurrRows + frg_idx); + } + + CUTLASS_PRAGMA_UNROLL + for (int reduction_rows = RowNum / 2; reduction_rows > 0; reduction_rows /= 2) { + CUTLASS_PRAGMA_UNROLL + for (int row_idx = 0; row_idx < reduction_rows; ++row_idx) { + frg_smem[row_idx] = reduce_smem(frg_smem[row_idx], frg_smem[row_idx + reduction_rows]); + } + } + gBuf_vec(frg_idx) = reinterpret_cast(frg_smem[0]); } sync_fn(); } @@ -959,9 +1259,8 @@ struct Sm90ColReduction { private: static_assert(Stages == 0, "Smem usage not supported yet"); static_assert(Alignment * sizeof_bits_v % 128 == 0, "sub-16B alignment not supported yet"); - static_assert( - (cute::is_same_v>) || // col vector reduction, e.g. per-row sum over all batches - (cute::is_same_v>)); // batched col vector reduction, e.g. 
per-row sum per batch + static_assert(is_static_v(StrideMNL{}))>); // batch stride can be dynamic or static + static_assert(take<0,2>(StrideMNL{}) == Stride<_1,_0>{}); static constexpr bool IsAtomic = is_atomic>::value; static_assert(not (IsAtomic && not FinalReduction), "atomic reduction must be final"); @@ -1042,6 +1341,7 @@ struct Sm90ColReduction { static cutlass::Status initialize_workspace(ProblemShape const& problem_shape, Arguments const& args, void* workspace, cudaStream_t stream, CudaHostAdapter* cuda_adapter = nullptr) { +#if !defined(CUTLASS_SKIP_REDUCTION_INIT) if constexpr (IsAtomic) { auto [M, N, K, L] = problem_shape; Layout mCol_layout = make_layout(make_shape(size<>(M),size<>(N),size<>(L)), args.dCol); @@ -1050,7 +1350,9 @@ struct Sm90ColReduction { } return Status::kSuccess; } - else if constexpr (FinalReduction) { + else +#endif + if constexpr (FinalReduction) { auto [M, N, K, L] = problem_shape; auto [tile_M, tile_N, tile_K] = CtaTileShapeMNK{}; size_t tile_counters_offset = product(ceil_div(make_shape(M,N,L), make_shape(tile_M, tile_N))) * tile_M * sizeof(ElementCompute); diff --git a/include/cutlass/epilogue/fusion/sm90_visitor_tma_warpspecialized.hpp b/include/cutlass/epilogue/fusion/sm90_visitor_tma_warpspecialized.hpp index 843640127d..4f7d99fa32 100644 --- a/include/cutlass/epilogue/fusion/sm90_visitor_tma_warpspecialized.hpp +++ b/include/cutlass/epilogue/fusion/sm90_visitor_tma_warpspecialized.hpp @@ -170,7 +170,7 @@ struct ConsumerStoreArgs { Residue residue_cD; ThrCoordTensor tCcD; ThrResidue residue_tCcD; - ThrSrcTensor const& tCrC; + ThrSrcTensor & tCrC; int thread_idx; CUTLASS_DEVICE @@ -185,7 +185,7 @@ struct ConsumerStoreArgs { Residue residue_cD, ThrCoordTensor tCcD, ThrResidue residue_tCcD, - ThrSrcTensor const& tCrC, + ThrSrcTensor & tCrC, int thread_idx) : problem_shape_mnkl(problem_shape_mnkl), tile_shape_mnk(tile_shape_mnk), @@ -361,14 +361,12 @@ struct Sm90VisitorImpl : Sm90VisitorImplBase { // Callbacks can store non-persistent variables (e.g. tensors) or copies of persistent variables CallbacksTuple callbacks_tuple; - // Before entry of the subtile load loop. Bulk copies usually performed here. - // Upon entry the producer_acquire of the first subtile lock has completed. - // full_mbarrier_ptr is the corresponding barrier for the subsequent producer_commit arrival + // Before entry of the subtile load loop CUTLASS_DEVICE void - begin(uint64_t* full_mbarrier_ptr, int load_iteration, bool issue_tma_load) { + begin() { for_each(callbacks_tuple, [&] (auto& callbacks) { - callbacks.begin(full_mbarrier_ptr, load_iteration, issue_tma_load); + callbacks.begin(); } ); } diff --git a/include/cutlass/epilogue/fusion/sm90_visitor_topk_softmax.hpp b/include/cutlass/epilogue/fusion/sm90_visitor_topk_softmax.hpp new file mode 100644 index 0000000000..53c0dce8ba --- /dev/null +++ b/include/cutlass/epilogue/fusion/sm90_visitor_topk_softmax.hpp @@ -0,0 +1,759 @@ +/*************************************************************************************************** + * Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. 
Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! \file + \brief Visitor tree Top-K + Softmax fusion operation for sm90 TMA warp-specialized epilogue +*/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/workspace.h" + +#include "cute/tensor.hpp" +#include "sm90_visitor_tma_warpspecialized.hpp" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass::epilogue::fusion { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +// Top-K + Softmax reduction across columns +// Performs a reduction of top-K values across N, and finally performs a softmax on them, +// and sets values not in the top-K to 0. +// +// Assumptions: +// 1. CTA_N >= N (single tile across N, the mode which is reduced) +// 2. EPI_N >= N (single epilogue tile across N, because we can reduce and revisit one +// epilogue tile at a time.) +// 3. Top-K value is either 2 or 4. +// + +namespace detail { + +// Implementations for add to sorted list and merging sorted lists, +// with fast paths for lists of size 2 and 4 (Top-2 and Top-4). +// Generic implementations may result in greater register use and branching, +// and should be avoided. +// Fast paths for Top-2 and Top-4 are written in inline PTX directly. + +CUTLASS_DEVICE +Array top_2_reduce_scalar(Array a, float scalar) { + Array out; + asm volatile( + "{\n" + " .reg .f32 mx;\n" + " .reg .pred p;\n" + " max.f32 mx, %3, %4;\n" + " setp.gtu.f32 p, %2, %4;\n" + " selp.f32 %1, mx, %2, p;\n" + " selp.f32 %0, %2, %4, p;\n" + "}\n" : "=f"(out[0]), "=f"(out[1]) : "f"(a[0]), "f"(a[1]), "f"(scalar)); + return out; +} + +CUTLASS_DEVICE +Array top_2_reduce(Array a, Array b) { + Array out; + asm volatile( + "{\n" + " .reg .v2 .f32 mx;\n" + " .reg .pred p;\n" + " max.f32 mx.x, %3, %4;\n" // max(a1, b0) + " max.f32 mx.y, %2, %5;\n" // max(a0, b1) + " setp.gtu.f32 p, %2, %4;\n" // a0 > b0 + " selp.f32 %1, mx.x, mx.y, p;\n" // a0 > b0 ? max(a1, b0) : max(a0, b1) + " selp.f32 %0, %2, %4, p;\n" // a0 > b0 ? 
a0 : b0 + "}\n" : "=f"(out[0]), "=f"(out[1]) : + "f"(a[0]), "f"(a[1]), "f"(b[0]), "f"(b[1])); + return out; +} + +CUTLASS_DEVICE +Array top_4_reduce_scalar(Array a, float scalar) { + Array out; + asm volatile( + "{\n" + " .reg .f32 mx;\n" // max(a3, b) + " .reg .pred p0;\n" // a0 > b + " .reg .pred p1;\n" // a1 > b + " .reg .pred p2;\n" // a2 > b + " max.f32 mx, %7, %8;\n" // max(a3, b) + " setp.gtu.f32 p0, %4, %8;\n" // a0 > b + " setp.gtu.f32 p1, %5, %8;\n" // a1 > b + " setp.gtu.f32 p2, %6, %8;\n" // a2 > b + " selp.f32 %3, mx, %6, p2;\n" // a2 > b ? max(a3, b) : a2 + " selp.f32 %2, %6, %8, p2;\n" // a1 = a2 > b ? a2 : b + " selp.f32 %2, %2, %5, p1;\n" // a1 > b ? max(a2, b) : a1 == a1 > b ? a1 : old_a1 + " selp.f32 %1, %5, %8, p1;\n" // a0 = a1 > b ? a1 : b + " selp.f32 %1, %1, %4, p0;\n" // a0 > b ? max(a1, b) : a0 == a0 > b ? a0 : old_a0 + " selp.f32 %0, %4, %8, p0;\n" // a0 = a0 > b ? a0 : b + "}\n" : + "=f"(out[0]), "=f"(out[1]), "=f"(out[2]), "=f"(out[3]) : + "f"(a[0]), "f"(a[1]), "f"(a[2]), "f"(a[3]), "f"(scalar)); + return out; +} + +CUTLASS_DEVICE +Array top_4_reduce(Array a, Array b) { + Array out; + asm volatile( + "{\n" + " .reg .f32 mxa0b1;\n" // max(a0, b1) + " .reg .f32 mxa1b0;\n" // max(a1, b0) + + " .reg .f32 mxa2b0;\n" // max(a2, b0) + " .reg .f32 mxa1b1;\n" // max(a1, b1) + " .reg .f32 mxa0b2;\n" // max(a1, b1) + + " .reg .f32 mxa1b2;\n" // max(a1, b2) + " .reg .f32 mxa2b1;\n" // max(a2, b1) + " max.f32 mxa1b2, %5, %10;\n" + " max.f32 mxa2b1, %6, %9;\n" + + " .reg .f32 mxa3b0;\n" // max(a1, b2) + " .reg .f32 mxa0b3;\n" // max(a2, b1) + " max.f32 mxa3b0, %7, %8;\n" + " max.f32 mxa0b3, %4, %11;\n" + + " .reg .pred pa0b0;\n" // a0 > b0 + " .reg .pred pa1b0;\n" // a1 > b0 + " .reg .pred pa2b0;\n" // a2 > b0 + " .reg .pred pa0b1;\n" // a0 > b1 + " .reg .pred pa1b1;\n" // a1 > b1 + " .reg .pred pa0b2;\n" // a0 > b2 + " .reg .pred pb2a0;\n" // b1 > a0 + " .reg .pred pb1a0;\n" // b1 > a0 + + " setp.gtu.f32 pa0b0, %4, %8;\n" // a0 > b0 + " setp.gtu.f32 pa1b0, %5, %8;\n" // a1 > b0 + " setp.gtu.f32 pa2b0, %6, %8;\n" // a2 > b0 + " setp.gtu.f32 pa0b1, %4, %9;\n" // a0 > b1 + " setp.gtu.f32 pa1b1, %5, %9;\n" // a1 > b1 + " setp.gtu.f32 pa0b2, %4, %10;\n" // a0 > b2 + + " not.pred pb2a0, pa0b2;\n" + " not.pred pb1a0, pa0b1;\n" + + " selp.f32 mxa1b0, %5, %8, pa1b0;\n" // max(a1, b0) + " selp.f32 mxa0b1, %4, %9, pa0b1;\n" // max(a0, b1) + + " selp.f32 mxa1b1, %5, %9, pa1b1;\n" // max(a1, b1) + " selp.f32 mxa2b0, %6, %8, pa2b0;\n" // max(a2, b0) + " selp.f32 mxa0b2, %4, %10, pa0b2;\n" // max(a0, b2) + + // a0 + " selp.f32 %0, %4, %8, pa0b0;\n" // a0 = a0 > b0 ? a0 : b0 + + // a1 + " selp.f32 %1, mxa1b0, mxa0b1, pa0b0;\n" // a1 = a0 > b0 ? max(a1, b0) : max(a0, b1) + + // a2 + " mov.f32 %2, mxa1b1;\n" // a2 = max(a1, b1) ** most likely case + " selp.f32 %2, mxa2b0, %2, pa1b0;\n" // a0 > a1 > b0 + " selp.f32 %2, mxa0b2, %2, pb1a0;\n" // b0 > b1 > a0 + + // a3 + " mov.f32 %3, mxa1b2;\n" // a3 = max(a1, b2) ** one of the most likely cases + " selp.f32 %3, mxa2b1, %3, pa1b1;\n" // a3 = a1 > b1 ? max(a2, b1) ** second most likely case + " selp.f32 %3, mxa3b0, %3, pa2b0;\n" // a0 > a1 > a2 > b0 + " selp.f32 %3, mxa0b3, %3, pb2a0;\n" // b0 > b1 > b2 > a0 + "}\n" : + "=f"(out[0]), "=f"(out[1]), "=f"(out[2]), "=f"(out[3]) : + "f"(a[0]), "f"(a[1]), "f"(a[2]), "f"(a[3]), + "f"(b[0]), "f"(b[1]), "f"(b[2]), "f"(b[3])); + return out; +} + +// Assumption: array elements are sorted in descending order +// (a[0] is the largest element in a[].) 
+template +CUTLASS_DEVICE +void add_element_to_desc_sorted_array(cutlass::Array& a, Element b) { + if constexpr (N == 2 && is_same_v) { + a = top_2_reduce_scalar(a, b); + } + else if constexpr (N == 4 && is_same_v) { + a = top_4_reduce_scalar(a, b); + } + else { + // slower generic path with branching, slower, and can cause register spill + CUTLASS_PRAGMA_UNROLL + for (int k = 0; k < N; ++k) { + if (a[k] <= b) { + // Shift down + CUTLASS_PRAGMA_UNROLL + for (int l = N - 1; l > k; --l) { + a[l] = a[l-1]; + } + a[k] = b; + } + } + } +} + +// Assumption: array elements are sorted in descending order +// (a[0] and b[0] are the largest elements in a[] and b[].) +template +CUTLASS_DEVICE +void merge_desc_sorted_arrays(cutlass::Array& a, const cutlass::Array& b) { + if constexpr (N == 2 && is_same_v) { + a = top_2_reduce(a, b); + } + else if constexpr (N == 4 && is_same_v) { + a = top_4_reduce(a, b); + } + else { + // slower generic path with branching, slower, and can cause register spill + int j = 0; + CUTLASS_PRAGMA_UNROLL + for (int k = 0; k < N; ++k) { + if (a[k] <= b[j]) { + // Shift down + CUTLASS_PRAGMA_UNROLL + for (int l = N - 1; l > k; --l) { + a[l] = a[l-1]; + } + a[k] = b[j]; + ++j; + } + } + } +} + +// Assumption: array elements are sorted in descending order +// (a[0] is the largest element in a[].) +template +CUTLASS_DEVICE +Element topk_logsumexp(cutlass::Array a) { + // Do one less `exp`, because we know what its result will be. + // Assume x is a set of `x_i`s, and `x_m` is the maximum of that set. + // logsumexp(x) = log(sum(x_i)) = m + log(sum(x_i - m)) = m + log(1 + sum_{i != m}(x_i - x_m)) + // Compute m + log(1 + sum_{i != m}(x_i - x_m)) + Element sum = Element(1.0); + CUTLASS_PRAGMA_UNROLL + for (int i = 1; i < N; ++i) { + sum += fast_exp(a[i] - a[0]); + } + return a[0] + fast_log(sum); +} + +CUTLASS_DEVICE +float fast_masked_softmax(float value, float minimum, float logsumexp) { + float new_value; + asm volatile( + "{\n" + " .reg .pred p0;\n" + // value >= minimum + " setp.geu.f32 p0, %1, %2;\n" + + " .reg .f32 x_lse;\n" + " .reg .f32 %%f<11>;\n" + " .reg .b32 %%r<3>;\n" + + // x_lse = value - minimum + " sub.rn.f32 x_lse, %1, %3;\n" + + // exp(x_lse) + // The following is derived from a ptx dump of expf. + // exp requires a base conversion from exp2. + " fma.rn.f32 %%f1, x_lse, 0f3BBB989D, 0f3F000000;\n" + " cvt.sat.f32.f32 %%f2, %%f1;\n" + " fma.rm.f32 %%f3, %%f2, 0f437C0000, 0f4B400001;\n" + " add.f32 %%f4, %%f3, 0fCB40007F;\n" + " neg.f32 %%f5, %%f4;\n" + " fma.rn.f32 %%f6, x_lse, 0f3FB8AA3B, %%f5;\n" + " fma.rn.f32 %%f7, x_lse, 0f32A57060, %%f6;\n" + " mov.b32 %%r1, %%f3;\n" + " shl.b32 %%r2, %%r1, 23;\n" + " mov.b32 %%f8, %%r2;\n" + " ex2.approx.ftz.f32 %%f9, %%f7;\n" + " mul.f32 %%f10, %%f9, %%f8;\n" + + // Mask or softmax + " selp.f32 %0, %%f10, 0f00000000, p0;\n" + "}\n" : "=f"(new_value) : "f"(value), "f"(minimum), "f"(logsumexp)); + return new_value; +} + +template +CUTLASS_DEVICE +Element masked_softmax(Element value, Element minimum, Element logsumexp) { + if constexpr (is_same_v) { + // Inline PTX implementation + // Significantly reduces register requirements + return fast_masked_softmax(value, minimum, logsumexp); + } + else { + return value < minimum ? 
Element(0.0) : fast_exp(value - logsumexp); + } +} + +} // namespace detail + +template < + int TopK, + int FragmentSize, + class CtaTileShapeMNK, + class EpilogueTile, + class ElementOutput, + class ElementCompute, + FloatRoundStyle RoundStyle, + int Alignment = 128 / sizeof_bits_v, + bool UseButterflyReduce = true +> +struct Sm90TopKSoftmaxColReduction { +private: + static_assert(is_same_v, "Fused Top-K + Softmax reduction requires FP32 accumulation."); + static_assert(TopK == 2 || TopK == 4, "Fused Top-K + Softmax reduction only supports K=2 and K=4."); + static_assert(Alignment * sizeof_bits_v % 128 == 0, "sub-16B alignment not supported yet"); + + // Reduction tensors + // We have two tensors for this EVT node: a reduction tensor and a tensor holding + // final reduction values (tCrSoftmax). The reason for this is that Top-K and Softmax + // require different reductions, but those luckily overlap. Top-K obviously needs at least + // two values (K >= 2), and softmax needs one value: logsumexp. Logsumexp is simply the log + // of sum of exponents over the set, and is equivalent to m + sum(exp(x_i - m)), where m is the + // maximum of all x_i elements. Since safe softmax for any element x_i is computed as + // softmax(x_i) = exp(x_i - m) / sum_j(exp(x_j - max)) + // we can track logsumexp instead of tracking two variables (sum of exps and the max). + // In addition, subtracting logsumexp from any element and taking its exp is equivalent to + // computing its softmax. + // + // The overlap between softmax and top-K is that we don't need to reduce logsumexp along the + // way at all, because any element not in the top-K is going to be masked out and set to 0. + // Therefore, we only reduce the top-K elements, and when done, compute their logsumexp and + // keep it, and the smallest element in the top-K for masking out non-top-K elements. + // + // This means that our final reduction result will always be 2 elements, regardless of the value + // of K: minimum of top-K, and logsumexp. + // + // For each reduction tensor, we define a new struct for readability. + + struct ReductionResult { + ElementCompute min_; + ElementCompute logsumexp_; + + CUTLASS_DEVICE + ReductionResult() { } + + CUTLASS_DEVICE + ReductionResult(ElementCompute min, ElementCompute logsumexp): + logsumexp_(logsumexp), min_(min) { } + + // Warp shuffle broadcast + CUTLASS_DEVICE + void shuffle_up_sync(uint32_t delta, int lane_id) { + static_assert(sizeof(ReductionResult) == sizeof(uint64_t)); + uint64_t r = reinterpret_cast(*this); + r = __shfl_up_sync(0xFFFFFFFF, r, delta); + *this = (lane_id - static_cast(delta) >= 0) ? reinterpret_cast(r) : *this; + } + }; + + struct TopKResult { + Array top_k_; + + CUTLASS_DEVICE + TopKResult() { + top_k_.fill(-cutlass::platform::numeric_limits::infinity()); + } + + // This is where we do the "final" reduction, where we compute + // the logsumexp for softmax, keep the smallest value in top-K, + // and discard the rest. 
+ CUTLASS_DEVICE + ReductionResult reduce_final() const { + return ReductionResult(top_k_[TopK - 1], topk_logsumexp(top_k_)); + } + + // Butterfly reduction + CUTLASS_DEVICE + void shuffle_xor_sync(int laneMask) { + if constexpr (TopK == 2) { + static_assert(sizeof(TopKResult) == sizeof(uint64_t)); + uint64_t top_k = reinterpret_cast(*this); + top_k = __shfl_xor_sync(0xFFFFFFFF, top_k, laneMask); + auto synced_v = reinterpret_cast(top_k); + detail::merge_desc_sorted_arrays(top_k_, synced_v.top_k_); + } + else if constexpr (TopK == 4) { + static_assert(sizeof(TopKResult) == 2 * sizeof(uint64_t)); + uint64_t* top_k_ptr = reinterpret_cast(this); + uint64_t top_k_arr[2]; + top_k_arr[0] = top_k_ptr[0]; + top_k_arr[1] = top_k_ptr[1]; + top_k_arr[0] = __shfl_xor_sync(0xFFFFFFFF, top_k_arr[0], laneMask); + top_k_arr[1] = __shfl_xor_sync(0xFFFFFFFF, top_k_arr[1], laneMask); + auto synced_v = reinterpret_cast(top_k_arr); + detail::merge_desc_sorted_arrays(top_k_, synced_v.top_k_); + } + else { + TopKResult synced_v; + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < TopK; ++i) { + synced_v.top_k_[i] = __shfl_xor_sync(0xFFFFFFFF, top_k_[i], laneMask); + } + detail::merge_desc_sorted_arrays(top_k_, synced_v.top_k_); + } + } + + // Warp shuffle reduction + CUTLASS_DEVICE + void shuffle_down_sync(uint32_t delta) { + if constexpr (TopK == 2) { + static_assert(sizeof(TopKResult) == sizeof(uint64_t)); + uint64_t top_k = reinterpret_cast(*this); + top_k = __shfl_down_sync(0xFFFFFFFF, top_k, delta); + auto synced_v = reinterpret_cast(top_k); + detail::merge_desc_sorted_arrays(top_k_, synced_v.top_k_); + } + else if constexpr (TopK == 4) { + static_assert(sizeof(TopKResult) == 2 * sizeof(uint64_t)); + uint64_t* top_k_ptr = reinterpret_cast(this); + uint64_t top_k_arr[2]; + top_k_arr[0] = top_k_ptr[0]; + top_k_arr[1] = top_k_ptr[1]; + top_k_arr[0] = __shfl_down_sync(0xFFFFFFFF, top_k_arr[0], delta); + top_k_arr[1] = __shfl_down_sync(0xFFFFFFFF, top_k_arr[1], delta); + auto synced_v = reinterpret_cast(top_k_arr); + detail::merge_desc_sorted_arrays(top_k_, synced_v.top_k_); + } + else { + TopKResult synced_v; + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < TopK; ++i) { + synced_v.top_k_[i] = __shfl_down_sync(0xFFFFFFFF, top_k_[i], delta); + } + detail::merge_desc_sorted_arrays(top_k_, synced_v.top_k_); + } + } + }; + +public: + struct SharedStorage { }; + + struct Arguments { }; + + struct Params { }; + + template + static constexpr Params + to_underlying_arguments(ProblemShape const& problem_shape, Arguments const& args, void* workspace) { + return {}; + } + + template + static bool + can_implement(ProblemShape const& problem_shape, Arguments const& args) { + auto [M, N, K, L] = problem_shape; + auto [tile_M, tile_N, tile_K] = CtaTileShapeMNK{}; + // Cross CTA reduction is not possible because there is no guarantee that all CTAs run + // concurrently. + // Cross epilogue tile reduction is possible, but re-visiting and applying reduction + // to accumulators is only possible for the current epilogue tile. 
+ auto [epi_M, epi_N] = EpilogueTile{}; + return N <= tile_N && N <= epi_N && N >= TopK; + } + + template + static size_t + get_workspace_size(ProblemShape const& problem_shape, Arguments const& args) { + return 0; + } + + template + static cutlass::Status + initialize_workspace(ProblemShape const& problem_shape, Arguments const& args, void* workspace, cudaStream_t stream, + CudaHostAdapter* cuda_adapter = nullptr) { + return Status::kSuccess; + } + + CUTLASS_DEVICE bool + is_producer_load_needed() const { + return false; + } + + CUTLASS_DEVICE bool + is_C_load_needed() const { + return false; + } + + CUTLASS_HOST_DEVICE + Sm90TopKSoftmaxColReduction() { } + + CUTLASS_HOST_DEVICE + Sm90TopKSoftmaxColReduction(Params const& params, SharedStorage const& shared_storage) + : params(params) { } + + Params params; + + template + CUTLASS_DEVICE auto + get_producer_load_callbacks(ProducerLoadArgs const& args) { + return EmptyProducerLoadCallbacks{}; + } + + template + struct ConsumerStoreCallbacks : EmptyConsumerStoreCallbacks { + CUTLASS_DEVICE + ConsumerStoreCallbacks(ArgsTuple&& args_tuple, Params const& params) + : args_tuple(cute::forward(args_tuple)), + params(params) {} + + ArgsTuple args_tuple; + Params const& params; + + template + CUTLASS_DEVICE auto + visit(Array const& frg_acc, int epi_v, int epi_m, int epi_n, + Array const& frg_input) { + + auto& [tCrTopK, tCrSoftmax, tCcCol, cCol, + lane_layout_MN, lane_mn, + residue_cCol, residue_tCcCol] = args_tuple; + Tensor tCcCol_mn = tCcCol(_,_,_,epi_m,epi_n); + + using ConvertInput = NumericArrayConverter; + ConvertInput convert_input{}; + + Array frg_I = convert_input(frg_input); + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < FragmentSize; ++i) { + auto thread_crd = tCcCol_mn(epi_v * FragmentSize + i); + if (elem_less(thread_crd, residue_tCcCol)) { + TopKResult& tCrCol_vmn = tCrTopK(epi_v * FragmentSize + i); + detail::add_element_to_desc_sorted_array(tCrCol_vmn.top_k_, frg_I[i]); + } + } + + return frg_input; + } + + template + CUTLASS_DEVICE void + reduce(STensor&& smem_buffer, SyncFn const& sync_fn, int epi_m, int epi_n, bool is_last_iteration, VTensor visit_results) { + + auto& [tCrTopK, tCrSoftmax, tCcCol, cCol, + lane_layout_MN, lane_mn, + residue_cCol, residue_tCcCol] = args_tuple; + + // fully OOB CTA in partially OOB cluster + if (not elem_less(cCol(_0{},_0{}), residue_cCol)) { + return; + } + Tensor tCcCol_mn = tCcCol(_,_,_,epi_m,epi_n); + + // `tCrTopK` and `tCrSoftmax` have 0-strides along modes that correspond to N, + // in order to reduce along modes in the `R2S` sublayout that correspond to N. + // This means we should modify and warp-reduce them according to their co-domain instead of + // their domain. Therefore we keep a filtered view of both and use them as necessary. + auto tCrTopK_f = filter(tCrTopK); + auto tCrSoftmax_f = filter(tCrSoftmax); + + // The pattern here is: reduce Top-K first, then compute logsumexp, keep it and the + // last element of Top-K, use the latter to mask the visited results, and the former + // to apply softmax. + // + // This gives us two options: reduce the Top-K with warp shuffles, have the reduced + // lanes compute logsumexp and pair it with the last Top-K element, and broadcast + // the result back using warp shuffles. + // + // Alternatively, we can do a butterfly reduction over Top-K, and have all lanes + // compute their own logsumexp and skip the broadcast. + if constexpr (UseButterflyReduce) { + // + // 1. 
Butterfly reduction + // + CUTLASS_PRAGMA_UNROLL + for (int j = 1; j < size<1>(lane_layout_MN); j *= 2) { + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(tCrTopK_f); ++i) { + tCrTopK_f(i).shuffle_xor_sync(j); + } + } + + // + // 2. Strip down reduced value and compute sum of exps + // + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(tCrSoftmax_f); ++i) { + tCrSoftmax_f(i) = tCrTopK_f(i).reduce_final(); + } + } + else { + // + // 1. Warp shuffle reduction + // + CUTLASS_PRAGMA_UNROLL + for (int reduction_cols = size<1>(lane_layout_MN) / 2; reduction_cols > 0; reduction_cols /= 2) { + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(tCrTopK_f); ++i) { + tCrTopK_f(i).shuffle_down_sync(lane_layout_MN(_0{},reduction_cols)); + } + } + + // + // 2. Strip down reduced value and compute sum of exps + // + bool is_reduced_lane = get<1>(lane_mn) == 0; + if (is_reduced_lane) { + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(tCrSoftmax_f); ++i) { + tCrSoftmax_f(i) = tCrTopK_f(i).reduce_final(); + } + } + + // + // 3. Broadcast reduced values to all participants + // + CUTLASS_PRAGMA_UNROLL + for (int broadcast_cols = 1; broadcast_cols <= size<1>(lane_layout_MN) / 2; broadcast_cols *= 2) { + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(tCrSoftmax_f); ++i) { + tCrSoftmax_f(i).shuffle_up_sync(lane_layout_MN(_0{},broadcast_cols), get<1>(lane_mn)); + } + } + } + + // + // 4. Re-visit and apply top-K and softmax + // + CUTLASS_PRAGMA_UNROLL + for (int epi_v = 0; epi_v < size(visit_results); ++epi_v) { + auto& visit_frag = visit_results(epi_v); + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < FragmentSize; ++i) { + visit_frag[i] = detail::masked_softmax( + visit_frag[i], + tCrSoftmax(epi_v * FragmentSize + i).min_, + tCrSoftmax(epi_v * FragmentSize + i).logsumexp_ + ); + } + } + + } + + CUTLASS_DEVICE void + end_loop(int epi_m, int epi_n) { + auto& [tCrTopK, tCrSoftmax, tCcCol, cCol, + lane_layout_MN, lane_mn, + residue_cCol, residue_tCcCol] = args_tuple; + + // Reset reduced top-K values for next tile + // This must be done because we only assume a single epilogue tile across N, + // but not M. + fill(tCrTopK, TopKResult()); + } + + CUTLASS_DEVICE void + end() { } + + }; + + template < + bool ReferenceSrc, // do register tensors reference the src or dst layout of the tiled copy + class... 
Args + > + CUTLASS_DEVICE auto + get_consumer_store_callbacks(ConsumerStoreArgs const& args) { + Layout ref_layout_MN = [&] () { + if constexpr (ReferenceSrc) { return get<0>(args.tiled_copy.get_layoutS_MN()); } + else { return get<0>(args.tiled_copy.get_layoutD_MN()); } + }(); // tile_mn -> tv_idx + + // Get the MN layout + coord of lanes to determine shuffle reduction iterations + using _W = Int; + Layout tv2lane = Layout,_W,_1>,Stride<_1,_0,_0>>{}; // tv_idx -> lane_idx + Layout ref2lane = composition(tv2lane, ref_layout_MN); // tile_mn -> lane_idx + Layout lane_layout_MN = make_layout(filter(get<0>(ref2lane)), filter(get<1>(ref2lane))); // lane_mn -> lane_idx + Layout inv_lane_layout_MN = right_inverse(lane_layout_MN); // lane_idx -> lane_mn + int lane_idx = canonical_lane_idx(); + auto lane_mn = idx2crd(inv_lane_layout_MN(lane_idx), shape(lane_layout_MN)); + + // Get the MN layout + coord of warps to determine smem reduction iterations + Layout tv2warp = Layout,_W,_1>,Stride<_0,_1,_0>>{}; // tv_idx -> warp_idx + Layout ref2warp = composition(tv2warp, ref_layout_MN); // tile_mn -> warp_idx + Layout warp_layout_MN = make_layout(filter(get<0>(ref2warp)), filter(get<1>(ref2warp))); // warp_mn -> warp_idx + + // Make sure there's only one warp across N so we can use warp shuffle intrinsics for reduction. + static_assert(decltype(size<1>(warp_layout_MN))::value <= 1); + + // Reduction layout + // We're assuming all elements in a row (over which we're performing the reduction) are + // visited in the same corresponding epilogue tile, and this is what allows us to apply the + // top-K + softmax operation within `reduce()`, by re-visiting the accumulated results. + // + // This presents a challenge, because the layout of the accumulated results is typically in + // in the register to shared memory shape, or: (R2S,R2S_M,R2S_N). + // This means that we still need to reduce this tensor along N. + // + // The solution is simple: we need to flatten the layout, identify modes that correspond to + // N and set their strides to 0, in order to map fragment indices corresponding to the same + // row back to the same element in the tensor. + // + // This requires some extra layout manipulation, which is as follows. + + // Create new accumulator layout with column broadcast + auto [M, N, K] = args.tile_shape_mnk; + auto thr_mma = args.tiled_mma.get_thread_slice(args.thread_idx); + auto gColReduce = make_tensor( + make_layout(make_shape(M, N), make_stride(_1{}, 0_c))); // (M,N) + auto tCrColReduce = make_tensor_like( // (FrgV, MMA_M, MMA_N) + thr_mma.partition_C(gColReduce).layout()); + + // Tile the new accumulator tensor according to R2S + ThrCopy thread_r2s = args.tiled_copy.get_slice(args.thread_idx); + Tensor tRS_rSoftmax = thread_r2s.retile_S(tCrColReduce); // ((R2S,R2S_V),MMA_M,MMA_N) + auto tCrC_layout = args.tCrC.layout(); // (R2S,R2S_M,R2S_N) + + // Compose the new accumulator R2S layout with the expected tCrC layout to get final + // reduction tensor layout. 
+ auto tCrSoftmax_layout = take<0, 3>(tRS_rSoftmax.layout()).compose(tCrC_layout); // (R2S,R2S_V) o (R2S,R2S_M,R2S_N) + + Tensor tCrTopK = make_tensor(tCrSoftmax_layout); // (R2S,R2S_M,R2S_N) + Tensor tCrSoftmax = make_tensor(tCrSoftmax_layout); // (R2S,R2S_M,R2S_N) + fill(tCrTopK, TopKResult()); + + auto args_tuple = make_tuple( + cute::move(tCrTopK), cute::move(tCrSoftmax), args.tCcD, args.cD, + lane_layout_MN, lane_mn, + args.residue_cD, args.residue_tCcD); + return ConsumerStoreCallbacks(std::move(args_tuple), params); + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass::epilogue::fusion + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/epilogue/thread/activation.h b/include/cutlass/epilogue/thread/activation.h index 92407733f8..9f1cd77434 100644 --- a/include/cutlass/epilogue/thread/activation.h +++ b/include/cutlass/epilogue/thread/activation.h @@ -178,8 +178,9 @@ struct Clamp { CUTLASS_HOST_DEVICE T operator()(T const& value, T const& lower_bound, T const& upper_bound) const { - maximum mx; - minimum mn; + constexpr bool PropagateNaN = true; + maximum mx; + minimum mn; return mn(mx(value, lower_bound), upper_bound); } @@ -196,8 +197,9 @@ struct Clamp> { CUTLASS_HOST_DEVICE Array operator()(Array const& values, T const& lower_bound, T const& upper_bound) const { - maximum> mx; - minimum> mn; + constexpr bool PropagateNaN = true; + maximum, PropagateNaN> mx; + minimum, PropagateNaN> mn; return mn(mx(values, lower_bound), upper_bound); } @@ -226,7 +228,7 @@ struct LeakyReLU { CUTLASS_HOST_DEVICE T operator()(T const& value, Arguments const& args = Arguments()) const { - this->operator()(value, args.leaky_alpha); + return this->operator()(value, args.leaky_alpha); } }; @@ -696,6 +698,57 @@ struct dReLU_Z> { } }; +// ElementwiseFilter operator +// Filters by a specific value and maps it to 0.0 +// Used in GEMM + comm +template +struct ElementwiseFilter { + + static const bool kIsHeavy = false; + + struct Arguments { + T value_to_filter = T(-0.0); + T filtered_value = T(0.0); + }; + + CUTLASS_HOST_DEVICE + T operator()(T const& value, T const& value_to_filter, T const& filtered_value) const { + T res = value == value_to_filter ? 
filtered_value : value; + return res; + } + + CUTLASS_HOST_DEVICE + T operator()(T const& value, Arguments const& args = Arguments()) const { + return this->operator()(value, args.value_to_filter, args.filtered_value); + } +}; + +template +struct ElementwiseFilter > { + + static const bool kIsHeavy = false; + + using Arguments = typename ElementwiseFilter::Arguments; + + CUTLASS_HOST_DEVICE + Array operator()(Array const& values, T const& value_to_filter, T const& filtered_value) const { + Array y; + ElementwiseFilter filter_op; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < int(values.size()); ++i) { + y[i] = filter_op(values[i], value_to_filter, filtered_value); + } + + return y; + } + + CUTLASS_HOST_DEVICE + Array operator()(Array const& values, Arguments const& args = Arguments()) const { + return this->operator()(values, args.value_to_filter, args.filtered_value); + } +}; + ///////////////////////////////////////////////////////////////////////////////////////////////// } // namespace thread diff --git a/include/cutlass/epilogue/thread/linear_combination_bias_elementwise.h b/include/cutlass/epilogue/thread/linear_combination_bias_elementwise.h index 7456ae8df4..c5ffdaa03f 100644 --- a/include/cutlass/epilogue/thread/linear_combination_bias_elementwise.h +++ b/include/cutlass/epilogue/thread/linear_combination_bias_elementwise.h @@ -127,15 +127,20 @@ class LinearCombinationBiasElementwise { public: using ElementOutput = ElementC_; + using ElementD = ElementOutput; using ElementC = ElementC_; using ElementAccumulator = ElementAccumulator_; using ElementCompute = ElementCompute_; + using ElementScalar = ElementCompute; using ElementZ = ElementZ_; using ElementT = ElementT_; using ElementVector = ElementVector_; static int const kElementsPerAccess = ElementsPerAccess; static int const kCount = kElementsPerAccess; + /// Follow cutlass3x EVT aliases + static bool const IsEltActSupported = true; + using ElementwiseOp = ElementwiseOp_; using BinaryOp = BinaryOp_; @@ -157,7 +162,7 @@ class LinearCombinationBiasElementwise { using FragmentOutput = FragmentZ; using ElementBias = ElementVector; using FragmentBias = Array; - using ActivationFunctor = ElementwiseOp; + using ActivationFn = ElementwiseOp; static const ScaleType::Kind kScale = ScaleType::Default; static bool const kIsHeavy = kIsHeavy_member_or_false::value; @@ -396,6 +401,118 @@ class LinearCombinationBiasElementwise { frag_T = convert_t(result_T); } } + + /// Applies the operation when elementwise_op require arguments and is_source_needed() is true + template + CUTLASS_HOST_DEVICE + void operator()( + ElementZ &Z, + ElementT &T, + ElementAccumulator const &AB, + ElementC const &C, + ElementCompute const &V, + ElementwiseArgs const &elementwise_args) const { + + ElementwiseOp elementwise_op; + BinaryOp binary_op; + + ElementCompute tmp_Accum = NumericConverter()(AB); + ElementCompute tmp_C = NumericConverter()(C); + + ElementCompute z = binary_op(alpha_ * tmp_Accum + beta_ * tmp_C, V); + ElementCompute result_Z = skip_elementwise_ ? 
z : elementwise_op(z, elementwise_args); + + NumericConverter convert_z; + Z = convert_z(result_Z); + + if constexpr (kStoreT) { + ElementCompute result_T = z; + NumericConverter convert_t; + T = convert_t(result_T); + } + } + + /// Applies the operation when elementwise_op require arguments and is_source_needed() is false + template + CUTLASS_HOST_DEVICE + void operator()( + ElementZ &Z, + ElementT &T, + ElementAccumulator const &AB, + ElementCompute const &V, + ElementwiseArgs const &elementwise_args) const { + + ElementwiseOp elementwise_op; + BinaryOp binary_op; + + ElementCompute tmp_Accum = NumericConverter()(AB); + + ElementCompute z = binary_op(alpha_ * tmp_Accum, V); + ElementCompute result_Z = skip_elementwise_ ? z : elementwise_op(z, elementwise_args); + + NumericConverter convert_z; + Z = convert_z(result_Z); + + if constexpr (kStoreT) { + ElementCompute result_T = z; + NumericConverter convert_t; + T = convert_t(result_T); + } + } + + /// Applies the operation when is_source_needed() is true + CUTLASS_HOST_DEVICE + void operator()( + ElementZ &Z, + ElementT &T, + ElementAccumulator const &AB, + ElementC const &C, + ElementCompute const &V) const { + + ElementwiseOpDispatcher elementwise_op(elementwise_); + BinaryOp binary_op; + + ElementCompute tmp_Accum = NumericConverter()(AB); + ElementCompute tmp_C = NumericConverter()(C); + + ElementCompute z = binary_op(alpha_ * tmp_Accum + beta_ * tmp_C, V); + ElementCompute result_Z = skip_elementwise_ ? z : elementwise_op(z); + + NumericConverter convert_z; + Z = convert_z(result_Z); + + if constexpr (kStoreT) { + ElementCompute result_T = z; + NumericConverter convert_t; + T = convert_t(result_T); + } + } + + /// Applies the operation when is_source_needed() is false + CUTLASS_HOST_DEVICE + void operator()( + ElementZ &Z, + ElementT &T, + ElementAccumulator const &AB, + ElementCompute const &V) const { + + ElementwiseOpDispatcher elementwise_op(elementwise_); + BinaryOp binary_op; + + ElementCompute tmp_Accum = NumericConverter()(AB); + + ElementCompute z = binary_op(alpha_ * tmp_Accum, V); + ElementCompute result_Z = skip_elementwise_ ? z : elementwise_op(z); + + NumericConverter convert_z; + Z = convert_z(result_Z); + + if constexpr (kStoreT) { + ElementCompute result_T = z; + NumericConverter convert_t; + T = convert_t(result_T); + } + } }; ///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/epilogue/threadblock/default_epilogue_tensor_op.h b/include/cutlass/epilogue/threadblock/default_epilogue_tensor_op.h index 1692cc3093..1d62f4fc35 100644 --- a/include/cutlass/epilogue/threadblock/default_epilogue_tensor_op.h +++ b/include/cutlass/epilogue/threadblock/default_epilogue_tensor_op.h @@ -225,6 +225,44 @@ struct DefaultIteratorsTensorOp< static int const kFragmentsPerIteration = 2; }; +/// Partial specialization for half <= int32_t x 8 epilogues avoids shared memory bank conflicts. 
+template < + typename ThreadblockShape, + typename WarpShape, + typename InstructionShape, + typename ThreadMap +> +struct DefaultIteratorsTensorOp< + bfloat16_t, + int32_t, + 8, + ThreadblockShape, + WarpShape, + InstructionShape, + ThreadMap> { + + using WarpTileIterator = cutlass::epilogue::warp::TileIteratorTensorOpMixed< + WarpShape, + InstructionShape, + int32_t, + 32, + 16, + 8, + 8 + >; + + using SharedLoadIterator = cutlass::epilogue::threadblock::SharedLoadIteratorMixed< + ThreadMap, + int32_t, + 32, + 16, + 8, + 8 + >; + + static int const kFragmentsPerIteration = 2; +}; + /// Partial specialization for half <= int32_t x 8 epilogues avoids shared memory bank conflicts. template < typename ThreadblockShape, diff --git a/include/cutlass/float8.h b/include/cutlass/float8.h index 5709ec9fed..38ea4008c2 100644 --- a/include/cutlass/float8.h +++ b/include/cutlass/float8.h @@ -574,6 +574,12 @@ struct alignas(1) float_e4m3_t : float8_base { int mantissa() const { return int(storage & Base::FP8_MANTISSA_MASK); } + + CUTLASS_HOST_DEVICE + friend bool isnan(float_e4m3_t const& x) { + return x.storage == uint8_t(0x7f); + } + }; /////////////////////////////////////////////////////////////// /// @@ -783,6 +789,12 @@ struct alignas(1) float_e5m2_t : float8_base { int mantissa() const { return int(storage & Base::FP8_MANTISSA_MASK); } + + CUTLASS_HOST_DEVICE + friend bool isnan(float_e5m2_t const& x) { + return x.storage == uint8_t(0x7f); + } + }; /////////////////////////////////////////////////////////////////////////////////////////////////// // diff --git a/include/cutlass/functional.h b/include/cutlass/functional.h index 65e49d5290..5b2bc3c67f 100644 --- a/include/cutlass/functional.h +++ b/include/cutlass/functional.h @@ -38,7 +38,6 @@ #include "cutlass/cutlass.h" #include "cutlass/numeric_types.h" #include "cutlass/platform/platform.h" - #if defined(__CUDACC_RTC__) #include "cutlass/floating_point_nvrtc.h" #endif @@ -234,7 +233,7 @@ template <> struct inverse_square_root { CUTLASS_HOST_DEVICE half_t operator()(half_t const &lhs) const { -#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ > 520 +#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ > 520) auto result = hrsqrt(reinterpret_cast<__half const &>(lhs)); return reinterpret_cast(result); #else @@ -350,7 +349,19 @@ template struct maximum { CUTLASS_HOST_DEVICE T operator()(T const &lhs, T const &rhs) const { - return (lhs < rhs ? rhs : lhs); + if constexpr (PropagateNaN && cutlass::platform::is_floating_point::value) { + using CUTLASS_CMATH_NAMESPACE :: isnan; + + // Call isnan unqualified, so argument-dependent lookup (ADL) + // will find overloads such as cutlass::isnan(half_t). + // Calling ::isnan or std::isnan directly would force + // implicit conversions to float of custom number types + // in the cutlass namespace (e.g., cutlass::half_t). + return lhs > rhs || isnan(lhs) ? lhs : rhs; + } + else { + return (lhs < rhs ? rhs : lhs); + } } }; @@ -363,23 +374,6 @@ template struct maximum_with_default_nan_propagation : public maximum {}; -// Maximum with nan propagation -// To propagate NANs, the "max" of a two element that contains NaNs should also return a NaN -template -struct maximum { - CUTLASS_HOST_DEVICE - T operator()(T const &lhs, T const &rhs) const { - using CUTLASS_CMATH_NAMESPACE :: isnan; - - // Call isnan unqualified, so argument-dependent lookup (ADL) - // will find overloads such as cutlass::isnan(half_t). 
- // Calling ::isnan or std::isnan directly would force - // implicit conversions to float of custom number types - // in the cutlass namespace (e.g., cutlass::half_t). - return lhs > rhs || isnan(lhs) ? lhs : rhs; - } -}; - template <> struct maximum { CUTLASS_HOST_DEVICE @@ -391,13 +385,14 @@ struct maximum { template <> struct maximum { CUTLASS_HOST_DEVICE - float operator()(float const lhs, float const rhs) const { + float operator()(float lhs, float rhs) const { #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) float res; asm volatile("max.NaN.f32 %0, %1, %2;\n" : "=f"(res) : "f"(lhs), "f"(rhs)); return res; #else using CUTLASS_CMATH_NAMESPACE :: isnan; + return lhs > rhs || isnan(lhs) ? lhs : rhs; #endif } @@ -418,20 +413,17 @@ template using maximum_with_nan_propogation = maximum_with_nan_propagation; template -struct minimum{ - CUTLASS_HOST_DEVICE - T operator()(T const &lhs, T const &rhs) const { - return (rhs < lhs ? rhs : lhs); - } -}; - -template -struct minimum { +struct minimum { CUTLASS_HOST_DEVICE T operator()(T const &lhs, T const &rhs) const { - using CUTLASS_CMATH_NAMESPACE :: isnan; + if constexpr (PropagateNaN && cutlass::platform::is_floating_point::value) { + using CUTLASS_CMATH_NAMESPACE :: isnan; - return lhs < rhs || isnan(lhs) ? lhs : rhs; + return lhs < rhs || isnan(lhs) ? lhs : rhs; + } + else { + return (rhs < lhs ? rhs : lhs); + } } }; @@ -443,6 +435,21 @@ struct minimum { } }; +template <> +struct minimum { + CUTLASS_HOST_DEVICE + float operator()(float lhs, float rhs) const { +#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) + float res; + asm volatile("min.NaN.f32 %0, %1, %2;\n" : "=f"(res) : "f"(lhs), "f"(rhs)); + return res; +#else + // No need for ADL; call std::isnan(float) on host and ::isnan(float) on device. + return lhs < rhs || (CUTLASS_CMATH_NAMESPACE :: isnan(lhs)) ? lhs : rhs; +#endif + } +}; + template struct minimum_with_nan_propagation : minimum {}; @@ -819,9 +826,9 @@ struct atomic_add void operator()(half2 *ptr, const half2 &data) { #if !defined(__CUDA_ARCH__) || (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 600)) - CUTLASS_UNUSED(ptr); - CUTLASS_UNUSED(data); - CUTLASS_NOT_IMPLEMENTED(); + CUTLASS_UNUSED(ptr); + CUTLASS_UNUSED(data); + CUTLASS_NOT_IMPLEMENTED(); #else // Vector-2 atomic reduction requires .target sm_60 or higher uint32_t word = reinterpret_cast(data); @@ -879,7 +886,6 @@ struct is_atomic> : platform::true_type {}; template struct is_atomic> : platform::true_type {}; - ///////////////////////////////////////////////////////////////////////////////////////////////// // // Partial specializations for nvcuda::wmma::fragment diff --git a/include/cutlass/gemm/collective/builders/sm90_common.inl b/include/cutlass/gemm/collective/builders/sm90_common.inl index 298793e886..8d95967f97 100644 --- a/include/cutlass/gemm/collective/builders/sm90_common.inl +++ b/include/cutlass/gemm/collective/builders/sm90_common.inl @@ -38,6 +38,7 @@ #include "cutlass/detail/dependent_false.hpp" #include "cute/atom/mma_traits_sm90_gmma.hpp" +#include "cute/atom/mma_traits_sm90_gmma_sparse.hpp" #include "cute/atom/copy_traits_sm90_tma.hpp" ///////////////////////////////////////////////////////////////////////////////////////////////// @@ -123,13 +124,12 @@ sm90_cluster_shape_to_tma_atom(UnimodalClusterShape) { } } -// Generates the most efficient possible TiledCopy with cp.async copy atom given a set of parameters. -template +// Generates the most efficient possible TiledCopy with simt copy atom(e.g. cp.async) given a set of parameters. 
+template constexpr auto -make_cp_async_gmem_tiled_copy() { +make_simt_gmem_tiled_copy() { using namespace cute; - using AlignmentType = cute::uint_byte_t(sizeof(Element)) * Alignment>; constexpr int TileSizeMN = cute::size(TileMN{}); constexpr int TileSizeK = cute::size(TileK{}); @@ -144,7 +144,7 @@ make_cp_async_gmem_tiled_copy() { static_assert(ThreadCount % threads_major == 0); static_assert(threads_minor == 0 || (TileSizeMN % threads_minor == 0)); return make_tiled_copy( - Copy_Atom, Element>{}, + CopyAtom{}, Layout,Int>, Stride, _1>>{}, Layout>>{}); @@ -157,13 +157,12 @@ make_cp_async_gmem_tiled_copy() { static_assert(ThreadCount % threads_major == 0); static_assert(threads_minor == 0 || (TileSizeK % threads_minor == 0)); return make_tiled_copy( - Copy_Atom, Element>{}, + CopyAtom{}, Layout,Int>, Stride< _1,Int>>{}, Layout,_1>>{}); - } - else { - static_assert(cute::is_void_v, "Unsupported gmem layout for automatic gmem tiled copy builder."); + } else { + static_assert(cute::is_void_v, "Unsupported gmem layout for automatic gmem tiled copy builder."); } } @@ -319,6 +318,62 @@ ss_smem_selector() } } +// Helper for SS GMMA smem selection that considers a tensor TileShape: +// (BLK_MN, BLK_K) +// or hierarchically +// ((BLK_MN0,BLK_MN1,...),(BLK_K0,BLK_K1,...)) +// and returns the largest GMMA::Layout that fits BLK_MN0 and BLK_K0 +template +CUTE_HOST_DEVICE constexpr +auto +ss_smem_selector_sparse() +{ + using namespace cute; + + auto BLK_MN0 = size<0>(BLK_MN{}); + auto BLK_K0 = size<0>(BLK_K{}); + + static_assert(BLK_MN0 % 8 == 0, "BLK_MN0 must be a multiple of 8."); + static_assert(BLK_K0 % 8 == 0, "BLK_K0 must be a multiple of 8."); + + if constexpr (major == GMMA::Major::MN) { + if constexpr (BLK_MN0 % size<0>(GMMA::Layout_MN_SW128_SpAtom{}) == 0) { + return GMMA::Layout_MN_SW128_SpAtom{}; + } + else if constexpr (BLK_MN0 % size<0>(GMMA::Layout_MN_SW64_SpAtom{}) == 0) { + return GMMA::Layout_MN_SW64_SpAtom{}; + } + else if constexpr (BLK_MN0 % size<0>(GMMA::Layout_MN_SW32_SpAtom{}) == 0) { + return GMMA::Layout_MN_SW32_SpAtom{}; + } + else if constexpr (BLK_MN0 % size<0>(GMMA::Layout_MN_INTER_SpAtom{}) == 0) { + return GMMA::Layout_MN_INTER_SpAtom{}; + } + else { + static_assert(BLK_MN0 % size<0>(GMMA::Layout_MN_INTER_SpAtom{}) == 0, + "BLK_MN0 must be a multiple of size<0>(GMMA::Layout_MN_INTER_Atom{})"); + } + } + else if constexpr (major == GMMA::Major::K) { + if constexpr (BLK_K0 % size<1>(GMMA::Layout_K_SW128_SpAtom{}) == 0) { + return GMMA::Layout_K_SW128_SpAtom{}; + } + else if constexpr (BLK_K0 % size<1>(GMMA::Layout_K_SW64_SpAtom{}) == 0) { + return GMMA::Layout_K_SW64_SpAtom{}; + } + else if constexpr (BLK_K0 % size<1>(GMMA::Layout_K_SW32_SpAtom{}) == 0) { + return GMMA::Layout_K_SW32_SpAtom{}; + } + else if constexpr (BLK_K0 % size<1>(GMMA::Layout_K_INTER_SpAtom{}) == 0) { + return GMMA::Layout_K_INTER_SpAtom{}; + } + else { + static_assert(BLK_K0 % size<1>(GMMA::Layout_K_INTER_SpAtom{}) == 0, + "BLK_K0 must be a multiple of size<1>(GMMA::Layout_K_INTER_Atom{})"); + } + } +} + template constexpr bool is_input_size_two_bytes() { diff --git a/include/cutlass/gemm/collective/builders/sm90_gmma_builder.inl b/include/cutlass/gemm/collective/builders/sm90_gmma_builder.inl index 0b3ecb15c6..a4cc768638 100644 --- a/include/cutlass/gemm/collective/builders/sm90_gmma_builder.inl +++ b/include/cutlass/gemm/collective/builders/sm90_gmma_builder.inl @@ -49,21 +49,21 @@ namespace cutlass::gemm::collective { namespace detail { -// Returns the maximum number of smem tiles that can be 
used with a given smem capacity, or overrides with manual count. +// Returns the maximum number of smem tiles that can be used with a given smem capacity, or overrides with manual count. template constexpr int compute_stage_count_or_override(StageCount stage_count) { return stages; } -// Returns the maximum number of smem tiles that can be used with a given smem capacity, or overrides with manual count. +// Returns the maximum number of smem tiles that can be used with a given smem capacity, or overrides with manual count. template constexpr int compute_stage_count_or_override(cute::Int stage_count) { return stages; } -// Returns the maximum number of smem tiles that can be used with a given smem capacity, or overrides with manual count. +// Returns the maximum number of smem tiles that can be used with a given smem capacity, or overrides with manual count. template constexpr int compute_stage_count_or_override(StageCountAutoCarveout stage_count) { @@ -78,7 +78,7 @@ compute_stage_count_or_override(StageCountAutoCarveout stage_cou return (CapacityBytes - carveout_bytes) / stage_bytes; } -// Returns the maximum number of smem tiles that can be used with a given smem capacity (with an optional scale matrix), or overrides with manual count. +// Returns the maximum number of smem tiles that can be used with a given smem capacity (with an optional scale matrix), or overrides with manual count. template constexpr int compute_stage_count_or_override_single_affine_transformed_input(StageCount stage_count) { @@ -86,16 +86,16 @@ compute_stage_count_or_override_single_affine_transformed_input(StageCount -constexpr int get_bits_for_possibly_void_element() { +constexpr int get_bits_for_possibly_void_element() { if constexpr (cute::is_same_v) { return 0; - } + } else { return sizeof_bits::value; } } -// Returns the maximum number of smem tiles that can be used with a given smem capacity (with an optional scale matrix), or overrides with manual count. +// Returns the maximum number of smem tiles that can be used with a given smem capacity (with an optional scale matrix), or overrides with manual count. template constexpr int compute_stage_count_or_override_single_affine_transformed_input(StageCountAutoCarveout stage_count) { @@ -113,7 +113,7 @@ compute_stage_count_or_override_single_affine_transformed_input(StageCountAutoCa static_assert(scale_bytes % 128 == 0, "Scale bytes must be a multiple of 128"); static_assert(zero_bytes % 128 == 0, "Zero bytes must be a multiple of 128"); - // When scales are void, s_bits will be 0 so no smem will be allocated for scales. + // When scales are void, s_bits will be 0 so no smem will be allocated for scales. 
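To make the stage-count carveout arithmetic below concrete, a toy host-side calculation with made-up numbers; the tile sizes, smem budget, carveout and pipeline-barrier size here are assumptions for illustration, not values taken from the builder.

#include <cstdio>

int main() {
  constexpr int capacity_bytes = 228 * 1024;  // assumed smem budget
  constexpr int carveout_bytes = 32 * 1024;   // assumed epilogue carveout
  constexpr int a_bytes = 128 * 64 * 2;       // BLK_M x BLK_K x sizeof(half), assumed tile
  constexpr int b_bytes = 128 * 64 * 2;       // BLK_N x BLK_K x sizeof(half), assumed tile
  constexpr int pipeline_bytes = 16;          // assumed mainloop barrier storage per stage
  constexpr int stage_bytes = a_bytes + b_bytes + pipeline_bytes;
  constexpr int stages = (capacity_bytes - carveout_bytes) / stage_bytes;  // 6 with these numbers
  std::printf("stage_bytes=%d stages=%d\n", stage_bytes, stages);
  return 0;
}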
constexpr int stage_bytes = cutlass::bits_to_bytes(a_bits * size<0>(TileShapeMNK{}) * size<2>(TileShapeMNK{})) + cutlass::bits_to_bytes(b_bits * size<1>(TileShapeMNK{}) * size<2>(TileShapeMNK{})) + @@ -140,7 +140,7 @@ is_warpspecialized_transpose_B(){ cutlass::gemm::detail::is_mn_major_B(); constexpr bool IsWarpSpecialized = cute::is_base_of_v || cute::is_base_of_v || - cute::is_base_of_v || + cute::is_base_of_v || cute::is_base_of_v || cute::is_base_of_v || cute::is_base_of_v; @@ -240,8 +240,8 @@ struct CollectiveBuilder< MainloopSm90TmaGmmaWarpSpecializedFP8, MainloopSm90TmaGmmaWarpSpecialized>>; - using SmemCopyAtomA = void; - using SmemCopyAtomB = void; + using SmemCopyAtomA = void; + using SmemCopyAtomB = void; using CollectiveOp = CollectiveMma< DispatchPolicy, @@ -296,7 +296,7 @@ struct CollectiveBuilder< (cute::is_same_v || cute::is_same_v || cute::is_same_v) && - detail::is_use_rmem_A()> + detail::is_use_rmem_A()> > { static_assert(is_static::value); static_assert(is_static::value); @@ -335,8 +335,8 @@ struct CollectiveBuilder< using DispatchPolicy = MainloopSm90TmaGmmaRmemAWarpSpecialized< PipelineStages, ClusterShape_MNK, KernelScheduleType>; - using SmemCopyAtomA = cute::conditional_t>; - using SmemCopyAtomB = cute::conditional_t, void>; + using SmemCopyAtomA = cute::conditional_t>; + using SmemCopyAtomB = cute::conditional_t, void>; using CollectiveOp = CollectiveMma< DispatchPolicy, @@ -404,7 +404,7 @@ public: using ElementA = detail::deduce_mixed_width_dtype_t<0, ElementPairA_>; using ElementB = detail::deduce_mixed_width_dtype_t<0, ElementPairB_>; static_assert(cute::is_tuple::value ^ cute::is_tuple::value || - (NeitherIsTuple && (sizeof_bits::value != sizeof_bits::value)), + (NeitherIsTuple && (sizeof_bits::value != sizeof_bits::value)), "Either A OR B must be a tuple or the widths of A and B must be different."); static constexpr bool IsANarrow = sizeof_bits::value < sizeof_bits::value; @@ -458,8 +458,8 @@ public: static constexpr int PipelineStages = detail::compute_stage_count_or_override_single_affine_transformed_input(StageCountType{}); - using SmemCopyAtomA = cute::conditional_t>; - using SmemCopyAtomB = cute::conditional_t, void>; + using SmemCopyAtomA = cute::conditional_t>; + using SmemCopyAtomB = cute::conditional_t, void>; using DispatchPolicy = MainloopSm90TmaGmmaRmemAWarpSpecializedMixedInput; @@ -794,11 +794,16 @@ struct CollectiveBuilder< static constexpr int NumLoadWarpGroups = cute::is_same_v ? 
2 : 1; - using GmemTiledCopyA = decltype(detail::make_cp_async_gmem_tiled_copy< - NumThreadsPerWarpGroup * NumLoadWarpGroups, ElementA, AlignmentA, TagToStrideA_t, + using AlignmentTypeA = cute::uint_byte_t(sizeof(ElementA)) * AlignmentA>; + using GmemCopyAtomA = cute::Copy_Atom, ElementA>; + using GmemTiledCopyA = decltype(detail::make_simt_gmem_tiled_copy< + GmemCopyAtomA, NumThreadsPerWarpGroup * NumLoadWarpGroups, AlignmentA, TagToStrideA_t, decltype(cute::get<0>(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>()); - using GmemTiledCopyB = decltype(detail::make_cp_async_gmem_tiled_copy< - NumThreadsPerWarpGroup * NumLoadWarpGroups, ElementB, AlignmentB, TagToStrideB_t, + + using AlignmentTypeB = cute::uint_byte_t(sizeof(ElementB)) * AlignmentB>; + using GmemCopyAtomB = cute::Copy_Atom, ElementB>; + using GmemTiledCopyB = decltype(detail::make_simt_gmem_tiled_copy< + GmemCopyAtomB, NumThreadsPerWarpGroup * NumLoadWarpGroups, AlignmentB, TagToStrideB_t, decltype(cute::get<1>(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>()); using SmemLayoutAtomA = decltype(detail::ss_smem_selector< @@ -895,14 +900,19 @@ struct CollectiveBuilder< static constexpr int NumLoadWarpGroups = 1; - using GmemTiledCopyA = decltype(detail::make_cp_async_gmem_tiled_copy< - NumThreadsPerWarpGroup * NumLoadWarpGroups, ElementA, AlignmentA, TagToStrideA_t, + using AlignmentTypeA = cute::uint_byte_t(sizeof(ElementA)) * AlignmentA>; + using GmemCopyAtomA = cute::Copy_Atom, ElementA>; + using GmemTiledCopyA = decltype(detail::make_simt_gmem_tiled_copy< + GmemCopyAtomA, NumThreadsPerWarpGroup * NumLoadWarpGroups, AlignmentA, TagToStrideA_t, decltype(cute::get<0>(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>()); - using GmemTiledCopyB = decltype(detail::make_cp_async_gmem_tiled_copy< - NumThreadsPerWarpGroup * NumLoadWarpGroups, ElementB, AlignmentB, TagToStrideB_t, + + using AlignmentTypeB = cute::uint_byte_t(sizeof(ElementB)) * AlignmentB>; + using GmemCopyAtomB = cute::Copy_Atom, ElementB>; + using GmemTiledCopyB = decltype(detail::make_simt_gmem_tiled_copy< + GmemCopyAtomB, NumThreadsPerWarpGroup * NumLoadWarpGroups, AlignmentB, TagToStrideB_t, decltype(cute::get<1>(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>()); - using SmemLayoutAtomA = decltype(detail::rs_smem_selector(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{})), IsWarpSpecializedTransposeB>()); using SmemLayoutAtomB = decltype(detail::rs_smem_selector(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{})), IsWarpSpecializedTransposeB>()); @@ -913,8 +923,8 @@ struct CollectiveBuilder< using DispatchPolicy = MainloopSm90CpAsyncGmmaRmemAWarpSpecialized< PipelineStages, ClusterShape_MNK, KernelScheduleType>; - using SmemCopyAtomA = cute::conditional_t>; - using SmemCopyAtomB = cute::conditional_t, void>; + using SmemCopyAtomA = cute::conditional_t>; + using SmemCopyAtomB = cute::conditional_t, void>; using CollectiveOp = CollectiveMma< DispatchPolicy, @@ -1025,3 +1035,4 @@ static constexpr bool IsMixedWidthInput = IsDifferentWidth || (IsDifferentWidth } // namespace cutlass::gemm::collective ///////////////////////////////////////////////////////////////////////////////////////////////// + diff --git a/include/cutlass/gemm/collective/builders/sm90_sparse_config.inl b/include/cutlass/gemm/collective/builders/sm90_sparse_config.inl new file mode 100644 index 0000000000..f9aa7bab2d --- /dev/null +++ b/include/cutlass/gemm/collective/builders/sm90_sparse_config.inl @@ -0,0 +1,268 @@ 
+/*************************************************************************************************** + * Copyright (c) 2024 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! 
\file + \brief Sparse configs specific for SM90 structure sparse kernels +*/ + + +#pragma once + +#include "cute/atom/mma_traits_sm90_gmma.hpp" // cute::GMMA::Major +#include "cute/layout.hpp" // cute::Layout, cute::Shape, cute::Stride +#include "cute/numeric/integral_constant.hpp" // cute::Int +#include "cute/numeric/numeric_types.hpp" // cute::sizeof_bits_v +#include "cute/pointer_sparse.hpp" // cute::is_sparse +#include "cute/util/type_traits.hpp" // cute::is_same_v, cute::conditional_t +#include "cutlass/fast_math.h" // cutlass::round_up +#include "cutlass/layout/matrix.h" // cutlass::RowMajor, cutlass::ColumnMajor + +namespace cutlass { + +using namespace cute; + +template< + class ElementAMma_, + GMMA::Major GmmaMajorA, + class ElementEMma_, + class MinTileShapeK = Int<32> +> +struct Sm90GemmSparseConfig { + + static_assert(cute::is_sparse::value, "ElementAMma MUST be sparse elem"); + static_assert(cute::is_sparse::value, "ElementEMma MUST be sparse elem"); + + // A + using ElementAMma = ElementAMma_; + using ElementAMmaRaw = typename ElementAMma::raw_type; + using ElementAMmaSparsity = Int; + + // Metadata (E) + using ElementEMma = ElementEMma_; + using ElementEMmaRaw = typename ElementEMma::raw_type; + using ElementEMmaSparsity = Int; + + // MMA type + static constexpr bool IsQmma = cute::is_same_v && ElementAMmaSparsity{} == _2{} || + cute::is_same_v && ElementAMmaSparsity{} == _2{}; + static constexpr bool IsImma = cute::is_same_v && ElementAMmaSparsity{} == _2{} || + cute::is_same_v && ElementAMmaSparsity{} == _2{}; + static constexpr bool IsHmma = cute::is_same_v && ElementAMmaSparsity{} == _2{} || + cute::is_same_v && ElementAMmaSparsity{} == _2{}; + static constexpr bool IsTfmma = cute::is_same_v && ElementAMmaSparsity{} == _2{} || + cute::is_same_v && ElementAMmaSparsity{} == _2{}; + static_assert(int(IsQmma) + int(IsImma) + int(IsHmma) + int(IsTfmma) == 1, "Ambigious Input Type Config (failed to choose MMA type)"); + + // Number of ElementARaw stored in ElementAMmaRaw. For Hopper this is always 1. + using ElemsARawPerElementAMmaRaw = _1; + + // ElementA Sparsity Ratio + using ElementASparsity = ElementAMmaSparsity; + static_assert(ElementASparsity{} == _2{}, "ElementASparsity must be 2 for Hopper Sparse Gemm"); + + // Logical/Physical ElementA per Chunk + using LogicalElemsAPerChunk = conditional_t; + using PhysicalElemsAPerChunk = Int; + + // Metadata Bits + using ElementEBitsPerChunk = _4; + using ElementEBitsPerElementAMma = cute::conditional_t; + + // Metadata Layout. Unit in corresbonding logical elements. + // Basic metadata block is (16,64) for 8-bit, (16,32) for 16-bit, (16,16) for 32-bit data types. + // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#sparse-wgmma-metadata-64n32-f16bf16 + // Tensor E layout atom stacks 4 basic blocks along M mode to align with WGMMA instruction shape and + // stacks 1-4 blocks along K mode and reorders memory layout to allow for vectorized loads from smem. 
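A tiny worked example of the metadata block-size arithmetic described in the comment above: the basic block covers 16 rows by 512 / sizeof_bits(Element) logical columns. The helper metadata_block_k is hypothetical and only restates that arithmetic.

#include <cstdio>

constexpr int metadata_block_k(int element_bits) { return 512 / element_bits; }

int main() {
  std::printf("8-bit  types: (16,%d)\n", metadata_block_k(8));   // (16,64)
  std::printf("16-bit types: (16,%d)\n", metadata_block_k(16));  // (16,32)
  std::printf("32-bit types: (16,%d)\n", metadata_block_k(32));  // (16,16)
  return 0;
}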
+ using BlockK = Int<512 / sizeof_bits_v>; + static_assert(MinTileShapeK{} % BlockK{} == 0, "MinTileShapeK must be a multiple of BlockK"); + using NumK = decltype(MinTileShapeK{} / BlockK{}); + + using TensorEAtom_32bit = decltype(make_ordered_layout(Shape, Shape<_8,_2,NumK>>{}, + Step , Step <_0,_4, _2>>{})); + + using TensorEAtom_16bit = decltype(make_ordered_layout(Shape, Shape<_16,_2,NumK>>{}, + Step , Step < _0,_4, _2>>{})); + + using TensorEAtom_8bit = decltype(make_ordered_layout(Shape<_64,MinTileShapeK>{}, + Step < _1, _0>{})); + + using TensorEAtom = cute::conditional_t<(IsQmma || IsImma), TensorEAtom_8bit, + cute::conditional_t>; + + // Logical elems that construct the atomK for tensorE/A. + using TensorEAtomK = Int(TensorEAtom{})>; + using TensorEAtomM = Int(TensorEAtom{})>; + + // Tensor E alignment requirements + using TensorEAlignmentM = TensorEAtomM; + using TensorEAlignmentK = TensorEAtomK; + + // Tensor A alignment requirements + // When A is MN major, TensorAAlignmentK needs to be multiplier of chunk size + // When A is K major, TensorAAlignmentK needs to be multiplier of TMA requirements times tensorA sparsity + // this is b.c. TensorACompressed needs to satisfy TMA requirements + using TensorAAlignmentK = cute::conditional_t>>; + + // When A is MN Major, TensorAAlignmentM needs to be multiplier of TMA requirements + // When A is K Major, no requirements on TensorAAlignmentM. + using TensorAAlignmentM = cute::conditional_t * ElemsARawPerElementAMmaRaw{}>, + _1>; + + // The following two functions are provided for user determine the static layouts type + CUTE_HOST_DEVICE + static constexpr auto + deduce_layoutA() { + using LayoutMMajor = Layout, + int32_t>, + Stride, + int64_t>>; + + using LayoutKMajor = Layout, + int32_t>, + Stride, + int64_t>>; + + if constexpr (GmmaMajorA == GMMA::Major::MN) { + return LayoutMMajor{}; + } + else { + return LayoutKMajor{}; + } + } + + CUTE_HOST_DEVICE + static constexpr auto + deduce_layoutE() { + return make_layout( + make_shape(make_shape(shape<0>(TensorEAtom{}), int32_t(0)), + make_shape(shape<1>(TensorEAtom{}), int32_t(0)), + int32_t(0)), + make_stride(make_stride(stride<0>(TensorEAtom{}), cute::Int{}), + make_stride(stride<1>(TensorEAtom{}), int64_t(0)), + int64_t(0)) + ); + } + + // This function is used to revert a CuTe layout to a Cutlass layout tag (RowMajor/ColumnMajor) + template + CUTE_HOST_DEVICE + static constexpr auto + deduce_layoutA_tag(Layout layout_a) { + /* + (m, (2, k/2), l) : (2, (1, m*2), m*k) M-major + (m, (2, k/2), l) : (k, (1, 2), m*k) K-major + */ + // Check if the given layout_a is possibly a sparse tensorA layout. 
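A standalone illustration, without CuTe, of the stride pattern the tag deduction below keys on. SparseALayout and deduce_tag are invented names for this sketch, and the 2:4 sparsity ratio and the M=128, K=64 sizes are assumptions; the only point carried over is that A is M-major exactly when the stride of the M mode equals the sparsity.

#include <cstdio>

struct SparseALayout {
  long stride_m;      // stride of the M mode
  long stride_chunk;  // stride of the 2-element sparsity mode (always 1)
  long stride_k2;     // stride of the K/2 chunk mode
};

const char* deduce_tag(SparseALayout const& l, int sparsity = 2) {
  // Mirrors the check above: M-major (ColumnMajor tag) iff stride_m == sparsity.
  return l.stride_m == sparsity ? "ColumnMajor" : "RowMajor";
}

int main() {
  int M = 128, K = 64;
  SparseALayout m_major{2, 1, 2L * M};  // (m,(2,k/2),l) : (2,(1,2m),...)
  SparseALayout k_major{K, 1, 2};       // (m,(2,k/2),l) : (k,(1,2),...)
  std::printf("M-major -> %s\n", deduce_tag(m_major));  // ColumnMajor
  std::printf("K-major -> %s\n", deduce_tag(k_major));  // RowMajor
  return 0;
}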
+ static_assert(rank_v == 3 && depth_v == 2, "Rank and depth mismatch with the sparse tensorA's layout."); + static_assert(rank(get<1>(ShapeA{})) == 2 && rank(flatten(ShapeA{})) == 4, + "Not likely to be a sparse tensorA's layout."); + static_assert(get<1,0>(StrideA{}) == 1 && get<1,0>(ShapeA{}) == ElementASparsity{}, + "Not likely to be a sparse tensorA's layout."); + static_assert(get<0>(StrideA{}) == ElementASparsity{} || get<1,1>(StrideA{}) == ElementASparsity{}, + "Not likely to be a sparse tensorA's layout."); + + if constexpr (get<0>(StrideA{}) == ElementASparsity{}) { + return cutlass::layout::ColumnMajor{}; + } + else { + return cutlass::layout::RowMajor{}; + } + } + + // Fill tensor A layout from dynamic problem shape + template + CUTE_HOST_DEVICE + static constexpr auto + fill_layoutA(ProblemShape problem_shape) { + + const auto [M, N, K, L] = problem_shape; + + // Round up to satisfy TensorA Alignment requirement + const auto M_AlignedAC = cutlass::round_up(M, TensorAAlignmentM{}); + const auto K_AlignedAC = cutlass::round_up(K, TensorAAlignmentK{}); + + if constexpr (GmmaMajorA == GMMA::Major::MN) { + return make_layout( + make_shape(int32_t(M_AlignedAC), + make_shape(ElementASparsity{}, int32_t(K_AlignedAC) / ElementASparsity{}), + int32_t(L)), + make_stride(ElementASparsity{}, + make_stride(_1{}, int64_t(M_AlignedAC) * ElementASparsity{}), + (L == 1) ? int64_t(0) : int64_t(M_AlignedAC * K_AlignedAC)) + ); + } + else { + return make_layout( + make_shape(int32_t(M_AlignedAC), + make_shape(ElementASparsity{}, int32_t(K_AlignedAC / ElementASparsity{})), + int32_t(L)), + make_stride(int64_t(K_AlignedAC), + make_stride(_1{}, ElementASparsity{}), + (L == 1) ? int64_t(0) : int64_t(M_AlignedAC * K_AlignedAC)) + ); + } + } + + // Fill tensor E layout from dynamic problem shape + template + CUTE_HOST_DEVICE + static constexpr auto + fill_layoutE(ProblemShape problem_shape) { + const auto [M, N, K, L] = problem_shape; + + // Round up to satisfy TensorEAlignment requirement + const auto M_AlignedE = cutlass::round_up(M, TensorEAlignmentM{}); + const auto K_AlignedE = cutlass::round_up(K, TensorEAlignmentK{}); + + // TensorEAtom first along m-dim, then along k-dim, then along batch + static_assert(TensorEAlignmentM{} == TensorEAtomM{}, "Shape below assumes TensorEAlignmentM == TensorEAtomM"); + static_assert(TensorEAlignmentK{} == TensorEAtomK{}, "Shape below assumes TensorEAlignmentK == TensorEAtomK"); + + return make_layout( + make_shape(make_shape(shape<0>(TensorEAtom{}), int32_t(M_AlignedE / TensorEAtomM{})), + make_shape(shape<1>(TensorEAtom{}), int32_t(K_AlignedE / TensorEAtomK{})), + int32_t(L)), + make_stride(make_stride(stride<0>(TensorEAtom{}), cute::Int{}), + make_stride(stride<1>(TensorEAtom{}), int64_t(M_AlignedE * TensorEAtomK{})), + (L == 1) ? int64_t(0) : int64_t(M_AlignedE * K_AlignedE)) + ); + } +}; + +} // namespace cutlass diff --git a/include/cutlass/gemm/collective/builders/sm90_sparse_gmma_builder.inl b/include/cutlass/gemm/collective/builders/sm90_sparse_gmma_builder.inl new file mode 100644 index 0000000000..9b608fe022 --- /dev/null +++ b/include/cutlass/gemm/collective/builders/sm90_sparse_gmma_builder.inl @@ -0,0 +1,388 @@ +/*************************************************************************************************** + * Copyright (c) 2024 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+ * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +#include "cutlass/gemm/collective/builders/sm90_common.inl" +#include "cutlass/gemm/collective/builders/sm90_sparse_config.inl" + +// SM90 Collective Builders should be used only starting CUDA 12.0 +#if (__CUDACC_VER_MAJOR__ >= 12) +#define CUTLASS_SM90_COLLECTIVE_BUILDER_SUPPORTED +#endif + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass::gemm::collective { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace detail { + +// Returns the maximum number of smem tiles that can be used with a given smem capacity, or overrides with manual count. +template +constexpr int +compute_stage_count_or_override_sparse(StageCount stage_count) { + return stages; +} + +// Returns the maximum number of smem tiles that can be used with a given smem capacity, or overrides with manual count. +template +constexpr int +compute_stage_count_or_override_sparse(cute::Int stage_count) { + return stages; +} + +// Returns the maximum number of smem tiles that can be used with a given smem capacity, or overrides with manual count. 
+template +constexpr int +compute_stage_count_or_override_sparse(StageCountAutoCarveout stage_count) { + constexpr auto mainloop_pipeline_bytes = sizeof(typename cutlass::PipelineTmaAsync<1>::SharedStorage); + constexpr auto a_bits = cute::sizeof_bits_v; + constexpr auto b_bits = cute::sizeof_bits_v; + constexpr auto e_bits = cute::sizeof_bits_v; + constexpr int stage_bytes = + cutlass::bits_to_bytes(a_bits * size<0>(TileShapeMNK{}) * size<2>(TileShapeMNK{})) + + cutlass::bits_to_bytes(b_bits * size<1>(TileShapeMNK{}) * size<2>(TileShapeMNK{})) + + cutlass::bits_to_bytes(e_bits * size<0>(TileShapeMNK{}) * size<2>(TileShapeMNK{})) + + static_cast(mainloop_pipeline_bytes); + + return (CapacityBytes - carveout_bytes) / stage_bytes; +} + +} // namespace detail + +///////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA_TMA_WS_SS_SPARSE +template < + class ElementA, + class GmemLayoutATag, + int AlignmentA, + class ElementB, + class GmemLayoutBTag, + int AlignmentB, + class ElementAccumulator, + class TileShape_MNK, + class ClusterShape_MNK, + class StageCountType, + class KernelScheduleType +> +struct CollectiveBuilder< + arch::Sm90, + arch::OpClassSparseTensorOp, + ElementA, + GmemLayoutATag, + AlignmentA, + ElementB, + GmemLayoutBTag, + AlignmentB, + ElementAccumulator, + TileShape_MNK, + ClusterShape_MNK, + StageCountType, + KernelScheduleType, + cute::enable_if_t< + (cute::is_same_v || + cute::is_same_v || + cute::is_same_v) && + not detail::is_use_rmem_A()> +> { + static_assert(is_static::value); + static_assert(is_static::value); +#ifndef CUTLASS_SM90_COLLECTIVE_BUILDER_SUPPORTED + static_assert(cutlass::detail::dependent_false, "Unsupported Toolkit for SM90 Collective Builder\n"); +#endif + static_assert(detail::is_aligned(), + "Should meet TMA alignment requirement\n"); + + static constexpr bool IsFP8Input = detail::is_input_fp8(); + static_assert(!IsFP8Input, "FP8 sparse collective currently only supports FastAccum schedules"); + + // For fp32 types, map to tf32 MMA value type + using ElementAMmaRaw = cute::conditional_t, tfloat32_t, ElementA>; + using ElementBMma = cute::conditional_t, tfloat32_t, ElementB>; + + static constexpr cute::GMMA::Major GmmaMajorA = detail::gmma_ss_tag_to_major_A(); + static constexpr cute::GMMA::Major GmmaMajorB = detail::gmma_ss_tag_to_major_B(); + + using AtomLayoutMNK = cute::conditional_t< + cute::is_same_v, + Layout>, Layout>>; + + using TiledMma = decltype(cute::make_tiled_mma(cute::GMMA::ss_op_selector_sparse< + ElementAMmaRaw, ElementBMma, ElementAccumulator, TileShape_MNK, GmmaMajorA, GmmaMajorB>(), AtomLayoutMNK{})); + + using ElementAMma = typename TiledMma::ValTypeA; + using ElementAMmaSparsity = Int; + using ElementEMma = typename TiledMma::ValTypeE; + using SparseConfig = cutlass::Sm90GemmSparseConfig(TileShape_MNK{}),_128{}))>; + + using LayoutA = decltype(SparseConfig::deduce_layoutA()); + using LayoutE = decltype(SparseConfig::deduce_layoutE()); + using LayoutPairAE = decltype(cute::make_tuple(LayoutA{}, LayoutE{})); + + using GmemTiledCopyA = decltype(detail::sm90_cluster_shape_to_tma_atom(shape<1>(ClusterShape_MNK{}))); + using GmemTiledCopyB = decltype(detail::sm90_cluster_shape_to_tma_atom(shape<0>(ClusterShape_MNK{}))); + + using SmemLayoutAtomA = decltype(detail::ss_smem_selector_sparse< + GmmaMajorA, ElementAMmaRaw, decltype(cute::get<0>(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{})), ElementAMmaSparsity>()); + using SmemLayoutAtomB = decltype(detail::ss_smem_selector< 
+ GmmaMajorB, ElementBMma, decltype(cute::get<1>(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>()); + + static constexpr int PipelineStages = detail::compute_stage_count_or_override_sparse(StageCountType{}); + using DispatchPolicy = MainloopSm90TmaGmmaWarpSpecializedSparse; + + using SmemCopyAtomA = void; + using SmemCopyAtomB = void; + + using CollectiveOp = CollectiveMma< + DispatchPolicy, + TileShape_MNK, + ElementA, + LayoutPairAE, + ElementB, + TagToStrideB_t, + TiledMma, + GmemTiledCopyA, + SmemLayoutAtomA, + SmemCopyAtomA, + cute::identity, + GmemTiledCopyB, + SmemLayoutAtomB, + SmemCopyAtomB, + cute::identity + >; +}; + +// GMMA_TMA_WS_SS_FP8_FAST_ACCUM_SPARSE +template < + class ElementA, + class GmemLayoutATag, + int AlignmentA, + class ElementB, + class GmemLayoutBTag, + int AlignmentB, + class ElementAccumulator, + class TileShape_MNK, + class ClusterShape_MNK, + class StageCountType, + class KernelScheduleType +> +struct CollectiveBuilder< + arch::Sm90, + arch::OpClassSparseTensorOp, + ElementA, + GmemLayoutATag, + AlignmentA, + ElementB, + GmemLayoutBTag, + AlignmentB, + ElementAccumulator, + TileShape_MNK, + ClusterShape_MNK, + StageCountType, + KernelScheduleType, + cute::enable_if_t< + (cute::is_same_v || + cute::is_same_v || + cute::is_same_v)> +> { + static_assert(is_static::value); + static_assert(is_static::value); + static_assert(detail::is_aligned(), + "Should meet TMA alignment requirement\n"); + static_assert(detail::is_input_fp8(), + "Only FP8 datatypes are compatible with these kernel schedules\n"); +#ifndef CUTLASS_SM90_COLLECTIVE_BUILDER_SUPPORTED + static_assert(cutlass::detail::dependent_false, "Unsupported Toolkit for SM90 Collective Builder\n"); +#endif + + static constexpr cute::GMMA::Major GmmaMajorA = detail::gmma_ss_tag_to_major_A(); + static constexpr cute::GMMA::Major GmmaMajorB = detail::gmma_ss_tag_to_major_B(); + + using AtomLayoutMNK = cute::conditional_t< + cute::is_same_v, + Layout>, Layout>>; + + using TiledMma = decltype(cute::make_tiled_mma(cute::GMMA::ss_op_selector_sparse< + ElementA, ElementB, ElementAccumulator, TileShape_MNK, GmmaMajorA, GmmaMajorB>(), AtomLayoutMNK{})); + + using ElementAMma = typename TiledMma::ValTypeA; + using ElementAMmaSparsity = Int; + using ElementEMma = typename TiledMma::ValTypeE; + using SparseConfig = cutlass::Sm90GemmSparseConfig(TileShape_MNK{}),_128{}))>; + + using LayoutA = decltype(SparseConfig::deduce_layoutA()); + using LayoutE = decltype(SparseConfig::deduce_layoutE()); + using LayoutPairAE = decltype(cute::make_tuple(LayoutA{}, LayoutE{})); + + using GmemTiledCopyA = decltype(detail::sm90_cluster_shape_to_tma_atom(shape<1>(ClusterShape_MNK{}))); + using GmemTiledCopyB = decltype(detail::sm90_cluster_shape_to_tma_atom(shape<0>(ClusterShape_MNK{}))); + + using SmemLayoutAtomA = decltype(detail::ss_smem_selector_sparse< + GmmaMajorA, ElementA, decltype(cute::get<0>(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{})), ElementAMmaSparsity>()); + using SmemLayoutAtomB = decltype(detail::ss_smem_selector< + GmmaMajorB, ElementB, decltype(cute::get<1>(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>()); + + static constexpr int PipelineStages = detail::compute_stage_count_or_override_sparse(StageCountType{}); + using DispatchPolicy = MainloopSm90TmaGmmaWarpSpecializedSparse; + + using SmemCopyAtomA = void; + using SmemCopyAtomB = void; + + using CollectiveOp = CollectiveMma< + DispatchPolicy, + TileShape_MNK, + ElementA, + LayoutPairAE, + ElementB, + TagToStrideB_t, + TiledMma, 
+ GmemTiledCopyA, + SmemLayoutAtomA, + SmemCopyAtomA, + cute::identity, + GmemTiledCopyB, + SmemLayoutAtomB, + SmemCopyAtomB, + cute::identity + >; +}; + +// GMMA_TMA_WS_RS_SPARSE +template < + class ElementA, + class GmemLayoutATag, + int AlignmentA, + class ElementB, + class GmemLayoutBTag, + int AlignmentB, + class ElementAccumulator, + class TileShape_MNK, + class ClusterShape_MNK, + class StageCountType, + class KernelScheduleType +> +struct CollectiveBuilder< + arch::Sm90, + arch::OpClassSparseTensorOp, + ElementA, + GmemLayoutATag, + AlignmentA, + ElementB, + GmemLayoutBTag, + AlignmentB, + ElementAccumulator, + TileShape_MNK, + ClusterShape_MNK, + StageCountType, + KernelScheduleType, + cute::enable_if_t< + (cute::is_same_v || + cute::is_same_v || + cute::is_same_v) && + detail::is_use_rmem_A()> +> { + static_assert(cutlass::detail::dependent_false, "Mainloop with sparse A sourced from RF is not implemented."); +}; + +// Sparse GMMA auto kernel schedule +template < + class ElementA, + class GmemLayoutATag, + int AlignmentA, + class ElementB, + class GmemLayoutBTag, + int AlignmentB, + class ElementAccumulator, + class TileShape_MNK, + class ClusterShape_MNK, + class StageCountType, + class KernelScheduleType +> +struct CollectiveBuilder< + arch::Sm90, + arch::OpClassSparseTensorOp, + ElementA, + GmemLayoutATag, + AlignmentA, + ElementB, + GmemLayoutBTag, + AlignmentB, + ElementAccumulator, + TileShape_MNK, + ClusterShape_MNK, + StageCountType, + KernelScheduleType, + cute::enable_if_t> +> { + static_assert(is_static::value); + static_assert(is_static::value); +#ifndef CUTLASS_SM90_COLLECTIVE_BUILDER_SUPPORTED + static_assert(cutlass::detail::dependent_false, "Unsupported Toolkit for SM90 Collective Builder\n"); +#endif + + static constexpr bool IsFP8Input = detail::is_input_fp8(); + + using KernelSchedule = cute::conditional_t(TileShape_MNK{}) == Int<64>{}, + cute::conditional_t, + cute::conditional_t>; + + using CollectiveOp = typename CollectiveBuilder< + arch::Sm90, + arch::OpClassSparseTensorOp, + ElementA, + GmemLayoutATag, + AlignmentA, + ElementB, + GmemLayoutBTag, + AlignmentB, + ElementAccumulator, + TileShape_MNK, + ClusterShape_MNK, + StageCountType, + KernelSchedule + >::CollectiveOp; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass::gemm::collective + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/gemm/collective/collective_builder.hpp b/include/cutlass/gemm/collective/collective_builder.hpp index 532bfecfb1..ccd8d8b3c7 100644 --- a/include/cutlass/gemm/collective/collective_builder.hpp +++ b/include/cutlass/gemm/collective/collective_builder.hpp @@ -38,4 +38,5 @@ #include "cutlass/gemm/collective/collective_builder_decl.hpp" #include "cutlass/gemm/collective/builders/sm90_gmma_builder.inl" +#include "cutlass/gemm/collective/builders/sm90_sparse_gmma_builder.inl" ///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/gemm/collective/collective_mma.hpp b/include/cutlass/gemm/collective/collective_mma.hpp index 7bcc075782..103da9af7b 100644 --- a/include/cutlass/gemm/collective/collective_mma.hpp +++ b/include/cutlass/gemm/collective/collective_mma.hpp @@ -43,6 +43,7 @@ #include "cutlass/gemm/collective/sm90_mma_tma_gmma_rs_warpspecialized.hpp" #include "cutlass/gemm/collective/sm90_mma_tma_gmma_rs_warpspecialized_mixed_input.hpp" #include 
"cutlass/gemm/collective/sm90_mma_tma_gmma_ss_warpspecialized.hpp" +#include "cutlass/gemm/collective/sm90_sparse_mma_tma_gmma_ss_warpspecialized.hpp" #include "cutlass/gemm/collective/sm90_mma_array_tma_gmma_ss_warpspecialized.hpp" #include "cutlass/gemm/collective/sm90_mma_tma_gmma_ss_warpspecialized_fp8.hpp" ///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/gemm/collective/sm90_mma_array_tma_gmma_ss_warpspecialized.hpp b/include/cutlass/gemm/collective/sm90_mma_array_tma_gmma_ss_warpspecialized.hpp index 75d7bb39e9..9825a16571 100644 --- a/include/cutlass/gemm/collective/sm90_mma_array_tma_gmma_ss_warpspecialized.hpp +++ b/include/cutlass/gemm/collective/sm90_mma_array_tma_gmma_ss_warpspecialized.hpp @@ -166,12 +166,12 @@ struct CollectiveMma< size<0>(ClusterShape{}))); // mcast along M mode for this N load, if any struct SharedStorage { - struct TensorStorage : cute::aligned_struct<128> { + struct TensorStorage : cute::aligned_struct<128, _0> { cute::array_aligned> smem_A; cute::array_aligned> smem_B; } tensors; - struct TensorMapStorage : cute::aligned_struct<128> { + struct TensorMapStorage : cute::aligned_struct<128, _0> { cute::TmaDescriptor smem_tensormap_A; cute::TmaDescriptor smem_tensormap_B; } tensormaps; @@ -720,7 +720,6 @@ struct CollectiveMma< ProblemShape_MNKL problem_shape_mnkl, int32_t next_batch) { if (cute::elect_one_sync()) { - // Replacing global_address for the next batch tensormaps_replace_global_address(shared_tensormaps, mainloop_params, next_batch); diff --git a/include/cutlass/gemm/collective/sm90_mma_multistage_gmma_rs_warpspecialized.hpp b/include/cutlass/gemm/collective/sm90_mma_multistage_gmma_rs_warpspecialized.hpp index 4b291db358..69b31fdabe 100644 --- a/include/cutlass/gemm/collective/sm90_mma_multistage_gmma_rs_warpspecialized.hpp +++ b/include/cutlass/gemm/collective/sm90_mma_multistage_gmma_rs_warpspecialized.hpp @@ -187,7 +187,7 @@ struct CollectiveMma< struct SharedStorage { - struct TensorStorage : cute::aligned_struct<256> { + struct TensorStorage : cute::aligned_struct<256, _0> { cute::array_aligned, 256> smem_A; cute::array_aligned, 256> smem_B; } tensors; diff --git a/include/cutlass/gemm/collective/sm90_mma_multistage_gmma_ss_warpspecialized.hpp b/include/cutlass/gemm/collective/sm90_mma_multistage_gmma_ss_warpspecialized.hpp index 90e7acd38c..e336bd4755 100644 --- a/include/cutlass/gemm/collective/sm90_mma_multistage_gmma_ss_warpspecialized.hpp +++ b/include/cutlass/gemm/collective/sm90_mma_multistage_gmma_ss_warpspecialized.hpp @@ -135,7 +135,7 @@ struct CollectiveMma< struct SharedStorage { - struct TensorStorage : cute::aligned_struct<128> { + struct TensorStorage : cute::aligned_struct<128, _0> { cute::array_aligned> smem_A; cute::array_aligned> smem_B; } tensors; diff --git a/include/cutlass/gemm/collective/sm90_mma_tma_gmma_rs_warpspecialized.hpp b/include/cutlass/gemm/collective/sm90_mma_tma_gmma_rs_warpspecialized.hpp index 43e05afa07..b30fed1c85 100644 --- a/include/cutlass/gemm/collective/sm90_mma_tma_gmma_rs_warpspecialized.hpp +++ b/include/cutlass/gemm/collective/sm90_mma_tma_gmma_rs_warpspecialized.hpp @@ -213,7 +213,7 @@ struct CollectiveMma< struct SharedStorage { - struct TensorStorage : cute::aligned_struct { + struct TensorStorage : cute::aligned_struct { cute::array_aligned, SmemAlignmentA> smem_A; cute::array_aligned, SmemAlignmentB> smem_B; } tensors; diff --git 
a/include/cutlass/gemm/collective/sm90_mma_tma_gmma_rs_warpspecialized_mixed_input.hpp b/include/cutlass/gemm/collective/sm90_mma_tma_gmma_rs_warpspecialized_mixed_input.hpp index 1f679c88ca..8c98d15c29 100644 --- a/include/cutlass/gemm/collective/sm90_mma_tma_gmma_rs_warpspecialized_mixed_input.hpp +++ b/include/cutlass/gemm/collective/sm90_mma_tma_gmma_rs_warpspecialized_mixed_input.hpp @@ -174,7 +174,7 @@ struct CollectiveMma< using SmemCopyAtomA = SmemCopyAtomA_; using SmemCopyAtomB = SmemCopyAtomB_; - using SmemCopyAtomScale = Copy_Atom; + using SmemCopyAtomScale = Copy_Atom; // We must ensure the type to be scaled goes to RF static constexpr bool SwapAB = !IsATransformed; @@ -202,6 +202,7 @@ struct CollectiveMma< static constexpr int IsSubbyteA = cute::sizeof_bits_v < 8; using TmaElementA = cute::conditional_t; + using TmaElementScale = uint_bit_t >; // in case we have array. translating to uint to satisfy tma descriptor's specialization using ArchTag = typename DispatchPolicy::ArchTag; @@ -273,6 +274,8 @@ struct CollectiveMma< static constexpr ConversionMode KernelConversionMode = get_conversion_mode(); static constexpr bool ModeHasScales = KernelConversionMode == ConversionMode::ConvertAndScale || KernelConversionMode == ConversionMode::ConvertAndScaleWithZero; + static constexpr bool UseScaleLookupTable = KernelConversionMode == ConversionMode::ConvertAndScale && + cutlass::detail::is_Array_v; static constexpr auto elements_per_smem_scale() { @@ -304,22 +307,30 @@ struct CollectiveMma< // These methods use some the public members of the class. For that reason, we define them after the public section. static constexpr uint32_t compute_tma_transaction_bytes_mk() { - constexpr uint32_t baseline_bytes = cutlass::bits_to_bytes(size<0>(SmemLayoutA{}) * size<1>(SmemLayoutA{}) * static_cast(cute::sizeof_bits_v)); + return cutlass::bits_to_bytes(size<0>(SmemLayoutA{}) * size<1>(SmemLayoutA{}) * static_cast(cute::sizeof_bits_v)); + } + + static constexpr uint32_t + compute_tma_transaction_bytes_nk() { + return cutlass::bits_to_bytes(size<0>(SmemLayoutB{}) * size<1>(SmemLayoutB{}) * static_cast(cute::sizeof_bits_v)); + } + static constexpr uint32_t + compute_tma_transaction_bytes_extra() { if constexpr (KernelConversionMode == ConversionMode::DirectConvert) { - return baseline_bytes; + return 0; } else if constexpr (ModeHasScales) { constexpr uint32_t scale_tx_bytes = cutlass::bits_to_bytes(size<0>(SmemLayoutScale{}) * size<1>(SmemLayoutScale{}) * static_cast(cute::sizeof_bits_v)); static_assert(scale_tx_bytes % 128 == 0, "Each scale stage must be 128B aligned."); // required by TMA if constexpr (KernelConversionMode == ConversionMode::ConvertAndScale) { - return baseline_bytes + scale_tx_bytes; + return scale_tx_bytes; } else if constexpr (KernelConversionMode == ConversionMode::ConvertAndScaleWithZero) { // Scale and zero share smem layout constexpr uint32_t zero_tx_bytes = cutlass::bits_to_bytes(size<0>(SmemLayoutScale{}) * size<1>(SmemLayoutScale{}) * static_cast(cute::sizeof_bits_v)); static_assert(zero_tx_bytes % 128 == 0, "Each zero stage must be 128B aligned."); // required by TMA - return baseline_bytes + scale_tx_bytes + zero_tx_bytes; + return scale_tx_bytes + zero_tx_bytes; } else { static_assert(cutlass::detail::dependent_false, "Type not handled in tma transaction bytes computation."); @@ -330,11 +341,6 @@ struct CollectiveMma< } } - static constexpr uint32_t - compute_tma_transaction_bytes_nk() { - return cutlass::bits_to_bytes(size<0>(SmemLayoutB{}) * size<1>(SmemLayoutB{}) * 
static_cast(cute::sizeof_bits_v)); - } - public: static constexpr size_t SmemAlignmentA = cutlass::detail::alignment_for_swizzle(SmemLayoutA{}); @@ -349,7 +355,7 @@ struct CollectiveMma< { static constexpr int scale_elements = elements_per_smem_scale(); static constexpr int zero_elements = elements_per_smem_zero(); - struct TensorStorage : cute::aligned_struct { + struct TensorStorage : cute::aligned_struct { cute::ArrayEngine> smem_A; cute::ArrayEngine> smem_B; cute::ArrayEngine smem_scale; @@ -389,14 +395,14 @@ struct CollectiveMma< public: // Assumption: StrideA is congruent with Problem_MK - using TMA_A = decltype(make_tma_copy( + using TMA_A = decltype(make_tma_copy_A_sm90( GmemTiledCopyA{}, make_tensor(Outer::get_logical_ptr(static_cast(nullptr)), repeat_like(InternalStrideA{}, int32_t(0)), InternalStrideA{}), SmemLayoutA{}(_,_,cute::Int<0>{}), - make_shape(shape<0>(TileShape{}), shape<2>(TileShape{})), - size<1>(ClusterShape{}))); // mcast along N mode for this M load, if any + TileShape{}, + ClusterShape{})); // mcast along N mode for this M load, if any - using TMA_Scale = decltype(make_tma_copy( + using TMA_Scale = decltype(make_tma_copy( GmemTiledCopyScale{}, make_tensor(Outer::get_logical_ptr(static_cast(nullptr)), repeat_like(NonVoidStrideScale{}, int32_t(0)), NonVoidStrideScale{}), SmemLayoutScale{}(_,_,cute::Int<0>{}), @@ -411,12 +417,12 @@ struct CollectiveMma< _1{})); // mcast along N mode for this M load, if any. Scale is ALWAYS loaded with A for RF kernel // Assumption: StrideB is congruent with Problem_NK - using TMA_B = decltype(make_tma_copy( + using TMA_B = decltype(make_tma_copy_B_sm90( GmemTiledCopyB{}, make_tensor(Outer::get_logical_ptr(static_cast(nullptr)), repeat_like(InternalStrideB{}, int32_t(0)), InternalStrideB{}), SmemLayoutB{}(_,_,cute::Int<0>{}), - make_shape(shape<1>(TileShape{}), shape<2>(TileShape{})), - size<0>(ClusterShape{}))); // mcast along M mode for this N load, if any + TileShape{}, + ClusterShape{})); // mcast along M mode for this N load, if any TMA_A tma_load_a; TMA_B tma_load_b; TMA_Scale tma_load_scale; @@ -424,8 +430,7 @@ struct CollectiveMma< int64_t scale_k; int group_size; uint32_t tma_transaction_bytes = TmaTransactionBytes; - uint32_t tma_transaction_bytes_mk = TmaTransactionBytesMK; - uint32_t tma_transaction_bytes_nk = TmaTransactionBytesNK; + int reload_factor = (group_size + size<2>(TileShape{}) - 1) / size<2>(TileShape{}); }; // @@ -466,31 +471,33 @@ struct CollectiveMma< Tensor tensor_a = make_tensor(get_logical_ptr(ptr_A), make_layout(make_shape(M,K,L), dA)); Tensor tensor_b = make_tensor(get_logical_ptr(ptr_B), make_layout(make_shape(N,K,L), dB)); - typename Params::TMA_A tma_load_a = make_tma_copy( + typename Params::TMA_A tma_load_a = make_tma_copy_A_sm90( GmemTiledCopyA{}, tensor_a, SmemLayoutA{}(_,_,cute::Int<0>{}), - make_shape(shape<0>(TileShape{}), shape<2>(TileShape{})), - size<1>(ClusterShape{})); // mcast along N mode for this M load, if any + TileShape{}, + ClusterShape{}); // mcast along N mode for this M load, if any - typename Params::TMA_B tma_load_b = make_tma_copy( + typename Params::TMA_B tma_load_b = make_tma_copy_B_sm90( GmemTiledCopyB{}, tensor_b, SmemLayoutB{}(_,_,cute::Int<0>{}), - make_shape(shape<1>(TileShape{}), shape<2>(TileShape{})), - size<0>(ClusterShape{})); // mcast along M mode for this N load, if any + TileShape{}, + ClusterShape{}); // mcast along M mode for this N load, if any - typename Params::TMA_Scale tma_load_scale; - typename Params::TMA_Zero tma_load_zero; + typename 
Params::TMA_Scale tma_load_scale{}; + typename Params::TMA_Zero tma_load_zero{}; + + uint32_t tma_transaction_bytes = TmaTransactionBytesMK + TmaTransactionBytesNK; if constexpr (KernelConversionMode == ConversionMode::DirectConvert) { - return { tma_load_a, tma_load_b, tma_load_scale, tma_load_zero, 0, 0, TmaTransactionBytes, TmaTransactionBytesMK, TmaTransactionBytesNK }; + return { tma_load_a, tma_load_b, tma_load_scale, tma_load_zero, 0, 0, tma_transaction_bytes, 1 }; } else if constexpr (ModeHasScales) { auto scale_k = (K + args.group_size - 1) / args.group_size; ElementScale const* ptr_S = args.ptr_S; StrideScale dS = args.dS; Tensor tensor_scale = make_tensor(get_logical_ptr(ptr_S), make_layout(make_shape(M,scale_k,L), dS)); - tma_load_scale = make_tma_copy( + tma_load_scale = make_tma_copy( GmemTiledCopyScale{}, tensor_scale, SmemLayoutScale{}(_,_,cute::Int<0>{}), @@ -498,7 +505,7 @@ struct CollectiveMma< _1{}); // mcast along N mode for this M load, if any if constexpr(KernelConversionMode == ConversionMode::ConvertAndScale) { - return { tma_load_a, tma_load_b, tma_load_scale, tma_load_zero, scale_k, args.group_size, TmaTransactionBytes, TmaTransactionBytesMK, TmaTransactionBytesNK }; + return { tma_load_a, tma_load_b, tma_load_scale, tma_load_zero, scale_k, args.group_size, tma_transaction_bytes + TmaTransactionBytesExtra, (args.group_size + size<2>(TileShape{}) - 1) / size<2>(TileShape{}) }; } else if constexpr(KernelConversionMode == ConversionMode::ConvertAndScaleWithZero) { Tensor tensor_zero = make_tensor(get_logical_ptr(args.ptr_Z), make_layout(make_shape(M,scale_k,L), dS)); @@ -508,7 +515,7 @@ struct CollectiveMma< SmemLayoutScale{}(_,_,cute::Int<0>{}), ScaleTileShape{}, _1{}); // mcast along N mode for this M load, if any - return { tma_load_a, tma_load_b, tma_load_scale, tma_load_zero, scale_k, args.group_size, TmaTransactionBytes, TmaTransactionBytesMK, TmaTransactionBytesNK }; + return { tma_load_a, tma_load_b, tma_load_scale, tma_load_zero, scale_k, args.group_size, tma_transaction_bytes + TmaTransactionBytesExtra, (args.group_size + size<2>(TileShape{}) - 1) / size<2>(TileShape{}) }; } else { static_assert(cutlass::detail::dependent_false, "Conversion mode not handled in to_underlying_arguments."); } @@ -571,7 +578,8 @@ struct CollectiveMma< static constexpr int K_PIPE_MAX = DispatchPolicy::Stages; static constexpr uint32_t TmaTransactionBytesMK = compute_tma_transaction_bytes_mk(); static constexpr uint32_t TmaTransactionBytesNK = compute_tma_transaction_bytes_nk(); - static constexpr uint32_t TmaTransactionBytes = TmaTransactionBytesMK + TmaTransactionBytesNK; + static constexpr uint32_t TmaTransactionBytesExtra = compute_tma_transaction_bytes_extra(); + static constexpr uint32_t TmaTransactionBytes = TmaTransactionBytesMK + TmaTransactionBytesNK + TmaTransactionBytesExtra; /// Issue Tma Descriptor Prefetch -- ideally from a single thread for best performance CUTLASS_DEVICE @@ -674,122 +682,117 @@ struct CollectiveMma< int lane_predicate = cute::elect_one_sync(); - if (lane_predicate) { - Tensor sA_ = make_tensor(make_smem_ptr(shared_tensors.smem_A.begin()), SmemLayoutA{}); // (BLK_M,BLK_K,PIPE) - Tensor sB_ = make_tensor(make_smem_ptr(shared_tensors.smem_B.begin()), SmemLayoutB{}); // (BLK_N,BLK_K,PIPE) - Tensor sA = as_position_independent_swizzle_tensor(sA_); // (BLK_M,BLK_K,PIPE) - Tensor sB = as_position_independent_swizzle_tensor(sB_); // (BLK_N,BLK_K,PIPE) + Tensor sA_ = make_tensor(make_smem_ptr(shared_tensors.smem_A.begin()), SmemLayoutA{}); // 
(BLK_M,BLK_K,PIPE) + Tensor sB_ = make_tensor(make_smem_ptr(shared_tensors.smem_B.begin()), SmemLayoutB{}); // (BLK_N,BLK_K,PIPE) + Tensor sA = as_position_independent_swizzle_tensor(sA_); // (BLK_M,BLK_K,PIPE) + Tensor sB = as_position_independent_swizzle_tensor(sB_); // (BLK_N,BLK_K,PIPE) - // - // Prepare the TMA loads for A, B and Scales - // - - constexpr uint32_t cluster_shape_x = get<0>(ClusterShape()); - uint2 cluster_local_block_id = {block_rank_in_cluster % cluster_shape_x, block_rank_in_cluster / cluster_shape_x}; + // + // Prepare the TMA loads for A, B and Scales + // + + constexpr uint32_t cluster_shape_x = get<0>(ClusterShape()); + uint2 cluster_local_block_id = {block_rank_in_cluster % cluster_shape_x, block_rank_in_cluster / cluster_shape_x}; - Tensor gA_mkl = get<0>(load_inputs); - Tensor gB_nkl = get<1>(load_inputs); + Tensor gA_mkl = get<0>(load_inputs); + Tensor gB_nkl = get<1>(load_inputs); - auto block_tma_a = mainloop_params.tma_load_a.get_slice(cluster_local_block_id.y); - auto block_tma_b = mainloop_params.tma_load_b.get_slice(cluster_local_block_id.x); + auto block_tma_a = mainloop_params.tma_load_a.get_slice(cluster_local_block_id.y); + auto block_tma_b = mainloop_params.tma_load_b.get_slice(cluster_local_block_id.x); - // Partition the inputs based on the current block coordinates. - auto [m_coord, n_coord, k_coord, l_coord] = blk_coord; - Tensor gA = gA_mkl(_,_,m_coord,_,l_coord); // (BLK_M,BLK_K,k) - Tensor gB = gB_nkl(_,_,n_coord,_,l_coord); // (BLK_N,BLK_K,k) + // Partition the inputs based on the current block coordinates. + auto [m_coord, n_coord, k_coord, l_coord] = blk_coord; + Tensor gA = gA_mkl(_,_,m_coord,_,l_coord); // (BLK_M,BLK_K,k) + Tensor gB = gB_nkl(_,_,n_coord,_,l_coord); // (BLK_N,BLK_K,k) - // Applies the mapping from block_tma_a - Tensor tAgA = block_tma_a.partition_S(gA); // (TMA,TMA_M,TMA_K,k) - Tensor tAsA = block_tma_a.partition_D(sA); // (TMA,TMA_M,TMA_K,PIPE) + // Applies the mapping from block_tma_a + Tensor tAgA = block_tma_a.partition_S(gA); // (TMA,TMA_M,TMA_K,k) + Tensor tAsA = block_tma_a.partition_D(sA); // (TMA,TMA_M,TMA_K,PIPE) - Tensor tBgB = block_tma_b.partition_S(gB); // (TMA,TMA_N,TMA_K,k) - Tensor tBsB = block_tma_b.partition_D(sB); // (TMA,TMA_N,TMA_K,PIPE) + Tensor tBgB = block_tma_b.partition_S(gB); // (TMA,TMA_N,TMA_K,k) + Tensor tBsB = block_tma_b.partition_D(sB); // (TMA,TMA_N,TMA_K,PIPE) - uint16_t mcast_mask_a = 0; - uint16_t mcast_mask_b = 0; - uint16_t mcast_mask_s = 0; + uint16_t mcast_mask_a = 0; + uint16_t mcast_mask_b = 0; + uint16_t mcast_mask_s = 0; - // Issue TmaLoads - // Maps the tile -> block, value - if constexpr (cute::is_same_v) { - auto block_layout = Layout{}; // (m,n) -> block_id - for (int n = 0; n < size<1>(block_layout); ++n) { - mcast_mask_a |= (uint16_t(1) << block_layout(cluster_local_block_id.x,n,Int<0>{})); - } + // Issue TmaLoads + // Maps the tile -> block, value + if constexpr (cute::is_same_v) { + auto block_layout = Layout{}; // (m,n) -> block_id + for (int n = 0; n < size<1>(block_layout); ++n) { + mcast_mask_a |= (uint16_t(1) << block_layout(cluster_local_block_id.x,n,Int<0>{})); } + } - if constexpr (cute::is_same_v) { - auto block_layout = Layout{}; // (m,n) -> block_id - for (int m = 0; m < size<0>(block_layout); ++m) { - mcast_mask_b |= (uint16_t(1) << block_layout(m,cluster_local_block_id.y,Int<0>{})); - } + if constexpr (cute::is_same_v) { + auto block_layout = Layout{}; // (m,n) -> block_id + for (int m = 0; m < size<0>(block_layout); ++m) { + mcast_mask_b |= (uint16_t(1) 
<< block_layout(m,cluster_local_block_id.y,Int<0>{})); } + } - auto extra_input_partitions = partition_extra_tma_inputs(mainloop_params, load_inputs, shared_tensors, cluster_local_block_id, m_coord, l_coord); + auto extra_input_partitions = partition_extra_tma_inputs(mainloop_params, load_inputs, shared_tensors, cluster_local_block_id, m_coord, l_coord); - // Mainloop - CUTLASS_PRAGMA_NO_UNROLL - for ( ; k_tile_count > 0; --k_tile_count) { - // LOCK smem_pipe_write for _writing_ - pipeline.producer_acquire(smem_pipe_write); + // Mainloop + CUTLASS_PRAGMA_NO_UNROLL + for ( ; k_tile_count > 0; --k_tile_count) { + // LOCK smem_pipe_write for _writing_ + pipeline.producer_acquire(smem_pipe_write); - // - // Copy gmem to smem for *k_tile_iter - // + // + // Copy gmem to smem for *k_tile_iter + // - using BarrierType = typename MainloopPipeline::ProducerBarrierType; - BarrierType* tma_barrier = pipeline.producer_get_barrier(smem_pipe_write); + using BarrierType = typename MainloopPipeline::ProducerBarrierType; + BarrierType* tma_barrier = pipeline.producer_get_barrier(smem_pipe_write); - int write_stage = smem_pipe_write.index(); - copy(mainloop_params.tma_load_a.with(*tma_barrier, mcast_mask_a), tAgA(_,_,_,*k_tile_iter), tAsA(_,_,_,write_stage)); - copy(mainloop_params.tma_load_b.with(*tma_barrier, mcast_mask_b), tBgB(_,_,_,*k_tile_iter), tBsB(_,_,_,write_stage)); + int write_stage = smem_pipe_write.index(); + if (cute::elect_one_sync()) copy(mainloop_params.tma_load_a.with(*tma_barrier, mcast_mask_a), tAgA(_,_,_,*k_tile_iter), tAsA(_,_,_,write_stage)); + if (cute::elect_one_sync()) copy(mainloop_params.tma_load_b.with(*tma_barrier, mcast_mask_b), tBgB(_,_,_,*k_tile_iter), tBsB(_,_,_,write_stage)); - if constexpr (KernelConversionMode == ConversionMode::DirectConvert) { - // Nothing extra to do. - } - else if constexpr (ModeHasScales) { - auto tSgS = get<0>(extra_input_partitions); - auto tSsS = get<1>(extra_input_partitions); - - // Temporary factor which will determine which k tile to reload from gmem. Needed so we don't modify tma transaction bytes - // on the fly. - // We must do a ceiling divide here to correctly handle with group_size == K. In that case, we don't require that K - // is a multiple of the threadblock tile K - const int ReloadFactor = (mainloop_params.group_size + size<2>(TileShape{}) - 1) / size<2>(TileShape{}); - const int scale_load_k = *k_tile_iter / ReloadFactor; // This will always be 0 when group_size == K. - copy(mainloop_params.tma_load_scale.with(*tma_barrier, mcast_mask_s), tSgS(_,_,_,scale_load_k), tSsS(_,_,_,write_stage)); - - if constexpr (KernelConversionMode == ConversionMode::ConvertAndScale) { - // Nothing extra to do - } - else if constexpr (KernelConversionMode == ConversionMode::ConvertAndScaleWithZero) { - auto tZgZ = get<2>(extra_input_partitions); - auto tZsZ = get<3>(extra_input_partitions); - copy(mainloop_params.tma_load_zero.with(*tma_barrier, mcast_mask_s), tZgZ(_,_,_,scale_load_k), tZsZ(_,_,_,write_stage)); - } - else { - static_assert(cutlass::detail::dependent_false, "Conversion mode not handled for TMA copy op."); - } + if constexpr (KernelConversionMode == ConversionMode::DirectConvert) { + // Nothing extra to do. + } + else if constexpr (ModeHasScales) { + auto tSgS = get<0>(extra_input_partitions); + auto tSsS = get<1>(extra_input_partitions); + + // Temporary factor which will determine which k tile to reload from gmem. Needed so we don't modify tma transaction bytes + // on the fly. 
+ // We must do a ceiling divide here to correctly handle with group_size == K. In that case, we don't require that K + // is a multiple of the threadblock tile K + int const scale_load_k = *k_tile_iter / mainloop_params.reload_factor; // This will always be 0 when group_size == K. + if (cute::elect_one_sync()) copy(mainloop_params.tma_load_scale.with(*tma_barrier, mcast_mask_s), tSgS(_,_,_,scale_load_k), tSsS(_,_,_,write_stage)); + + if constexpr (KernelConversionMode == ConversionMode::ConvertAndScale) { + // Nothing extra to do } + else if constexpr (KernelConversionMode == ConversionMode::ConvertAndScaleWithZero) { + auto tZgZ = get<2>(extra_input_partitions); + auto tZsZ = get<3>(extra_input_partitions); + if (cute::elect_one_sync()) copy(mainloop_params.tma_load_zero.with(*tma_barrier, mcast_mask_s), tZgZ(_,_,_,scale_load_k), tZsZ(_,_,_,write_stage)); + } else { static_assert(cutlass::detail::dependent_false, "Conversion mode not handled for TMA copy op."); - } + } + } + else { + static_assert(cutlass::detail::dependent_false, "Conversion mode not handled for TMA copy op."); + } - ++k_tile_iter; + ++k_tile_iter; - // Advance smem_pipe_write - ++smem_pipe_write; - } + // Advance smem_pipe_write + ++smem_pipe_write; } } /// Perform a Producer Epilogue to prevent early exit of blocks in a Cluster CUTLASS_DEVICE void load_tail(MainloopPipeline pipeline, PipelineState smem_pipe_write) { - int lane_predicate = cute::elect_one_sync(); - // Issue the epilogue waits - if (lane_predicate) { + if (cute::elect_one_sync()) { /* This helps avoid early exit of blocks in Cluster * Waits for all stages to either be released (all * Consumer UNLOCKs), or if the stage was never used @@ -868,13 +871,6 @@ struct CollectiveMma< Tensor tCrA_copy_view = smem_thr_copy_A.retile_D(tCrA_load); // (CPY,CPY_M,CPY_K) - // Compute the max vector length that can be used to copy A. This will match the vector width of the - // conversions used. It helps by allowing the compiler to convert using the same register that was used - // to load the data from smem. This significantly reduces the need to move data among registers. - // Note that this is correct even if copy fails to vectorize, since the granularity at which we perform - // the conversion does not impact correctness. 
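A worked example, with made-up tile and group sizes, of the reload_factor / scale_load_k bookkeeping above: reload_factor is a ceiling divide of group_size by the threadblock tile K, and scale_load_k selects which scale k-tile to fetch, collapsing to 0 for every k-tile when group_size equals K.

#include <cstdio>

int main() {
  int const tile_k = 64;   // threadblock tile K (assumed)
  int const K = 512;       // problem K (assumed)
  int const group_sizes[] = {128, K};
  for (int group_size : group_sizes) {
    int reload_factor = (group_size + tile_k - 1) / tile_k;  // ceiling divide
    std::printf("group_size=%d reload_factor=%d\n", group_size, reload_factor);
    for (int k_tile = 0; k_tile < K / tile_k; ++k_tile) {
      int scale_load_k = k_tile / reload_factor;  // always 0 when group_size == K
      std::printf("  k_tile=%d -> scale_load_k=%d\n", k_tile, scale_load_k);
    }
  }
  return 0;
}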
- using A_CPY_VEC = decltype(max_common_vector(tCsA, tCrA_copy_view)); - // Partition of thread -> shared and thread -> RF auto partitioned_extra_info = partition_extra_mma_info(mma_thread_slice, shared_tensors); auto copy_partitions_extra_info = retile_extra_mma_info(tiled_mma, partitioned_extra_info, warp_group_thread_idx); @@ -915,16 +911,21 @@ struct CollectiveMma< // copy smem->rmem for A operand copy_A_and_extra_info(smem_tiled_copy_A, tCsA, tCrA_copy_view, partitioned_extra_info, copy_partitions_extra_info, 0, read_stage); - - transform_A_kblock(tCrA_load, A_CPY_VEC{}, tCrA_mma, partitioned_extra_info, 0); + if (K_BLOCK_MAX > 1) { // prefetch next block + copy_A_and_extra_info(smem_tiled_copy_A, tCsA, tCrA_copy_view, + partitioned_extra_info, copy_partitions_extra_info, 1, read_stage); + } + transform_A_kblock(tCrA_load, tCrA_mma, partitioned_extra_info, 0); // Unroll the K mode manually to set scale D to 1 CUTLASS_PRAGMA_UNROLL for (int k_block = 0; k_block < K_BLOCK_MAX; ++k_block) { - if (k_block < K_BLOCK_MAX - 1) { + if (k_block < K_BLOCK_MAX - 2) { // prefetch next block copy_A_and_extra_info(smem_tiled_copy_A, tCsA, tCrA_copy_view, - partitioned_extra_info, copy_partitions_extra_info, k_block + 1, read_stage); - transform_A_kblock(tCrA_load, A_CPY_VEC{}, tCrA_mma, partitioned_extra_info, k_block + 1); + partitioned_extra_info, copy_partitions_extra_info, k_block + 2, read_stage); + } + if (k_block < K_BLOCK_MAX - 1) { + transform_A_kblock(tCrA_load, tCrA_mma, partitioned_extra_info, k_block + 1); } warpgroup_arrive(); // (V,M) x (V,N) => (V,M,N) @@ -936,11 +937,15 @@ struct CollectiveMma< --k_tile_count; if (k_tile_count > 0) { // Wait for K_BLOCK_MAX - 1 to be in flight to ensure that it is safe to overwrite the A registers for the first mma. 
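An abstract sketch of the copy / convert / MMA software pipelining that the restructured mainloop in the surrounding hunks aims for: while block k is being multiplied, block k+1 is being converted and block k+2 is being fetched. copy_block, transform_block and mma_block are placeholders for the smem-to-RF copy, the dtype conversion and the warpgroup MMA; only the +1 / +2 prefetch ordering is taken from the patch.

#include <cstdio>

void copy_block(int k)      { std::printf("copy    k=%d\n", k); }
void transform_block(int k) { std::printf("convert k=%d\n", k); }
void mma_block(int k)       { std::printf("mma     k=%d\n", k); }

void pipelined_k_blocks(int K_BLOCK_MAX) {
  copy_block(0);
  if (K_BLOCK_MAX > 1) copy_block(1);  // prefetch next block
  transform_block(0);
  for (int k = 0; k < K_BLOCK_MAX; ++k) {
    if (k < K_BLOCK_MAX - 2) copy_block(k + 2);       // prefetch two blocks ahead
    if (k < K_BLOCK_MAX - 1) transform_block(k + 1);  // convert one block ahead
    mma_block(k);                                     // multiply the current block
  }
}

int main() { pipelined_k_blocks(4); return 0; }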
- warpgroup_wait(); pipeline.consumer_wait(smem_pipe_read, barrier_token); copy_A_and_extra_info(smem_tiled_copy_A, tCsA, tCrA_copy_view, partitioned_extra_info, copy_partitions_extra_info, 0, smem_pipe_read.index()); - transform_A_kblock(tCrA_load, A_CPY_VEC{}, tCrA_mma, partitioned_extra_info, 0); + if (K_BLOCK_MAX > 1) { // prefetch next block + copy_A_and_extra_info(smem_tiled_copy_A, tCsA, tCrA_copy_view, + partitioned_extra_info, copy_partitions_extra_info, 1, smem_pipe_read.index()); + } + warpgroup_wait(); + transform_A_kblock(tCrA_load, tCrA_mma, partitioned_extra_info, 0); } } @@ -971,9 +976,8 @@ struct CollectiveMma< tiled_mma.accumulate_ = GMMA::ScaleOut::One; warpgroup_commit_batch(); - warpgroup_wait(); if (k_block == K_BLOCK_MAX - 1) { - // We have K_BLOCK_MAX - 1 GMMA instructions pending for this stage, so we can release prior barrier + warpgroup_wait(); // We have K_BLOCK_MAX - 1 GMMA instructions pending for this stage, so we can release prior barrier pipeline.consumer_release(smem_pipe_release); // UNLOCK smem_pipe_release, done _computing_ on it ++smem_pipe_release; } @@ -986,12 +990,18 @@ struct CollectiveMma< pipeline.consumer_wait(smem_pipe_read, barrier_token); copy_A_and_extra_info(smem_tiled_copy_A, tCsA, tCrA_copy_view, partitioned_extra_info, copy_partitions_extra_info, 0, smem_pipe_read.index()); - transform_A_kblock(tCrA_load, A_CPY_VEC{}, tCrA_mma, partitioned_extra_info, 0); + if (K_BLOCK_MAX > 1) { // prefetch next block + copy_A_and_extra_info(smem_tiled_copy_A, tCsA, tCrA_copy_view, + partitioned_extra_info, copy_partitions_extra_info, 1, smem_pipe_read.index()); + } + transform_A_kblock(tCrA_load, tCrA_mma, partitioned_extra_info, 0); } else { - copy_A_and_extra_info(smem_tiled_copy_A, tCsA, tCrA_copy_view, - partitioned_extra_info, copy_partitions_extra_info, k_block + 1, read_stage); - transform_A_kblock(tCrA_load, A_CPY_VEC{}, tCrA_mma, partitioned_extra_info, k_block + 1); + if (k_block < K_BLOCK_MAX - 2) { // prefetch next block + copy_A_and_extra_info(smem_tiled_copy_A, tCsA, tCrA_copy_view, + partitioned_extra_info, copy_partitions_extra_info, k_block + 2, read_stage); + } + transform_A_kblock(tCrA_load, tCrA_mma, partitioned_extra_info, k_block + 1); } } warpgroup_fence_operand(accum); @@ -1018,17 +1028,20 @@ struct CollectiveMma< cute::gemm(tiled_mma, tCrA_mma(_,_,k_block), tCrB(_,_,k_block,read_stage), accum); tiled_mma.accumulate_ = GMMA::ScaleOut::One; warpgroup_commit_batch(); - warpgroup_wait(); - if (k_block == K_BLOCK_MAX - 1) { - // release prior barrier + if (k_block == K_BLOCK_MAX - 1) { // release prior barrier + warpgroup_wait(); pipeline.consumer_release(smem_pipe_release); // UNLOCK smem_pipe_release, done _computing_ on it ++smem_pipe_release; } + if (k_block < K_BLOCK_MAX - 2) { // prefetch next block + copy_A_and_extra_info(smem_tiled_copy_A, tCsA, tCrA_copy_view, + partitioned_extra_info, copy_partitions_extra_info, k_block + 2, read_stage); + } if (k_block < K_BLOCK_MAX - 1) { copy_A_and_extra_info(smem_tiled_copy_A, tCsA, tCrA_copy_view, partitioned_extra_info, copy_partitions_extra_info, k_block + 1, read_stage); - transform_A_kblock(tCrA_load, A_CPY_VEC{}, tCrA_mma, partitioned_extra_info, k_block + 1); + transform_A_kblock(tCrA_load, tCrA_mma, partitioned_extra_info, k_block + 1); } } } @@ -1110,10 +1123,20 @@ struct CollectiveMma< // nothing to do return cute::make_tuple(); } + else if constexpr (UseScaleLookupTable) { + Tensor sS = make_tensor(make_smem_ptr(shared_tensors.smem_scale.begin()), SmemLayoutScale{});// 
(BLK_M,BLK_SCALE_K,PIPE) + Tensor tCsS = mma_thread_slice.partition_A(sS); + Tensor tCrS_neg = make_tensor(mma_thread_slice.partition_fragment_A(sS(_,_,Int<0>{})).layout()); + Tensor tCrS_pos = make_tensor(mma_thread_slice.partition_fragment_A(sS(_,_,Int<0>{})).layout()); + + if constexpr (KernelConversionMode == ConversionMode::ConvertAndScale) { + return cute::make_tuple(tCsS, tCrS_neg, tCrS_pos); + } + } else if constexpr (ModeHasScales) { Tensor sS = make_tensor(make_smem_ptr(shared_tensors.smem_scale.begin()), SmemLayoutScale{});// (BLK_M,BLK_SCALE_K,PIPE) Tensor tCsS = mma_thread_slice.partition_A(sS); - Tensor tCrS = make_tensor(mma_thread_slice.partition_fragment_A(sS(_,_,Int<0>{})).shape()); + Tensor tCrS = make_tensor(mma_thread_slice.partition_fragment_A(sS(_,_,Int<0>{})).layout()); if constexpr (KernelConversionMode == ConversionMode::ConvertAndScale) { return cute::make_tuple(tCsS, tCrS); @@ -1121,7 +1144,7 @@ struct CollectiveMma< else if constexpr (KernelConversionMode == ConversionMode::ConvertAndScaleWithZero) { Tensor sZ = make_tensor(make_smem_ptr(shared_tensors.smem_zero.begin()), SmemLayoutScale{});// (BLK_M,BLK_SCALE_K,PIPE) Tensor tCsZ = mma_thread_slice.partition_A(sZ); - Tensor tCrZ = make_tensor(mma_thread_slice.partition_fragment_A(sZ(_,_,Int<0>{})).shape()); + Tensor tCrZ = make_tensor(mma_thread_slice.partition_fragment_A(sZ(_,_,Int<0>{})).layout()); return cute::make_tuple(tCsS, tCrS, tCsZ, tCrZ); } else { @@ -1210,159 +1233,293 @@ struct CollectiveMma< } } } + + // Helper functions to select packing for conversion + template + struct select_packing { // Naive packing policy + static constexpr auto value() { + return Int, sizeof_bits_v))>{}; + } + }; + template struct select_packing { + static constexpr auto value() { return Int{}; } + }; + template struct select_packing { + static constexpr auto value() { return Int{}; } + }; + template struct select_packing { + static constexpr auto value() { return Int{}; } + }; + template struct select_packing { + static constexpr auto value() { return Int{}; } + }; + template struct select_packing { + static constexpr auto value() { return Int{}; } + }; + template struct select_packing { + static constexpr auto value() { return Int{}; } + }; + + CUTLASS_DEVICE + static uint32_t to_reg(Array const& source) { + return static_cast( + reinterpret_cast(source)); + } + CUTLASS_DEVICE + static uint32_t to_reg(Array const& source) { + return reinterpret_cast(source); + } + // The core converter uses a lookup table to converts i4 -> 8 bit value. 
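+ // Sketch of the idea: each packed int4 nibble is split into its sign bit and a
+ // 3-bit magnitude. The magnitude indexes (via prmt) into an 8-byte table of
+ // pre-scaled candidates held in two 32-bit registers, and the sign bit then
+ // selects between the "negative" and "positive" candidate byte, so four nibbles
+ // are converted per 32-bit result without any arithmetic on the values themselves.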
+ template + CUTLASS_DEVICE + static Array lookup_table_convert( + cute::Int _, + Array const& source, + TensorPos const& scale_neg, + TensorNeg const& scale_pos, + int scale_idx) { + + static_assert(N == 4 || N == 8); + uint32_t res[N / 4]; + + // View the input as reg + uint32_t reg = to_reg(source); + + // Determines if to get from the signed or unsigned candidates + static constexpr uint32_t immLut = (0xf0 & 0xcc) | 0xaa; + uint32_t sign; // ((reg & 0x88888888) | 0x64206420) >> 1 + asm volatile( + "{\n" + " lop3.b32 %0, %1, %2, %3, %4;\n" \ + "}\n" + : "=r"(sign) + : "r"(reg), "n"(0x88888888), "n"(0x64206420), "n"(immLut) + ); + sign = sign >> 1; + + // Ignore sign bit when indexing into LUT + uint32_t lut_idx = reg & 0x77777777; + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N / 4; ++i, lut_idx >>=16, sign >>=16) { + Array const& _scale_neg = reinterpret_cast const&>(scale_neg[scale_idx + i * 4]); + Array const& _scale_pos = reinterpret_cast const&>(scale_pos[scale_idx + i * 4]); + asm volatile( + "{\n" + " .reg .b32 pos, neg ;\n" \ + " prmt .b32 neg, %3, %4, %1 ;\n" \ + " prmt .b32 pos, %5, %6, %1 ;\n" \ + " prmt .b32 %0, pos, neg, %2 ;\n" \ + "}\n" + : "=r"(res[i]) + : "r"(lut_idx), "r"(sign), "r"(_scale_neg[0]), "r"(_scale_neg[1]), "r"(_scale_pos[0]), "r"(_scale_pos[1]) + ); + } + return reinterpret_cast&>(res); + } + + template + CUTLASS_DEVICE + static void static_check_scale(Layout const& tensor) { + static_assert(shape<0>(Layout{}) >= 4 && stride<0>(Layout{}) == 0, "At least 4 adjacent weights in a thread must share the same scale."); + } + template + CUTLASS_DEVICE + static void static_check_scale(Tensor const& tensor) { + static_check_scale(flatten(Layout{})); + } /// Utilities to transform A. - template CUTLASS_DEVICE void transform_A_kblock( - TCrA_load const& tCrA_load, - cute::Int vec_A, - TCrA_mma& tCrA_mma, + Tensor const& tCrA_load, + Tensor& tCrA_mma, cute::tuple const& partitioned_extra_info, int const k_block) { + static_assert(is_rmem::value, "Input tensor for A conversion must come from registers"); + static_assert(is_rmem::value, "Output tensor for A conversion must come from registers"); + static_assert(cosize_v == cosize_v); + static_assert(size_v == cosize_v); + static_assert(size_v == cosize_v); + using SrcType = typename EngineIn::value_type; + using DstType = typename EngineOut::value_type; + + auto const& src = tCrA_load(_, _, k_block); + auto const& dst = tCrA_mma(_, _, k_block); + auto pSrc = raw_pointer_cast(src.data()); + auto pDst = const_cast(raw_pointer_cast(dst.data())); + constexpr int num_elements = decltype(size(src))::value; + if constexpr (KernelConversionMode == ConversionMode::DirectConvert) { - transform_internal_A(tCrA_load(_, _, k_block), vec_A, tCrA_mma(_, _, k_block)); + constexpr int pack = decltype(select_packing::value())::value; + using Converter = cutlass::NumericArrayConverter; + using SrcArray = cutlass::Array; + using DstArray = cutlass::Array; + constexpr int iters = num_elements / pack; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < iters; ++i) { + SrcArray const* pSrcArr = reinterpret_cast(pSrc) + i; + DstArray* pDstArr = reinterpret_cast(pDst) + i; + *pDstArr = Converter::convert(*pSrcArr); + } } + else if constexpr (UseScaleLookupTable) { + static_assert(is_same_v, "Lookup table only supports int4 being the quant type now."); + static_assert(sizeof_bits_v == 64, "Lookup table only supports 8 8bit scale values now."); + static_assert(num_elements % 4 == 0 && num_elements >= 4, "Lookup table requires a vector size of 4x when 
converting."); + constexpr int pack = num_elements % 8 == 0? 8 : 4; + constexpr int iters = num_elements / pack; + using SrcArray = cutlass::Array; + using DstArray = cutlass::Array; + + auto const& tCrS_neg = cute::get<1>(partitioned_extra_info); + auto const& tCrS_pos = cute::get<2>(partitioned_extra_info); + auto const& scale_neg = tCrS_neg(_, _, k_block); + auto const& scale_pos = tCrS_pos(_, _, k_block); + CUTE_STATIC_ASSERT_V(size(src) == size(scale_neg)); + + static_check_scale(scale_neg); + static_check_scale(scale_pos); + if (k_block == 0) { + auto pNeg = raw_pointer_cast(tCrS_neg.data()); + auto pPos = const_cast(raw_pointer_cast(tCrS_pos.data())); + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < cosize(tCrS_neg.layout()); ++i) + { + // pPos[i] = pNeg[i] & 0x7F7F7F7F7F7F7F00; + cutlass::Array const& _scale_neg = reinterpret_cast const&>(pNeg[i]); + cutlass::Array & _scale_pos = reinterpret_cast &>(pPos[i]); + asm volatile( + "{\n" + " and .b32 %0, %2, %4 ;\n" \ + " and .b32 %1, %3, %5 ;\n" \ + "}\n" + : "=r"(_scale_pos[0]), "=r"(_scale_pos[1]) + : "r"(_scale_neg[0]), "r"(_scale_neg[1]), "n"(0x7F7F7F00), "n"(0x7F7F7F7F) + ); + } + } + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < iters; i ++) { + SrcArray const* pSrcArr = reinterpret_cast(raw_pointer_cast(src.data())) + i; + DstArray* pDstArr = reinterpret_cast(raw_pointer_cast(dst.data())) + i; + + *pDstArr = lookup_table_convert(Int{}, *pSrcArr, scale_neg, scale_pos, i * pack); + } + } else if constexpr (KernelConversionMode == ConversionMode::ConvertAndScale) { - auto tCrS = cute::get<1>(partitioned_extra_info); - transform_internal_A(tCrA_load(_, _, k_block), vec_A, make_fragment_like(tCrA_mma)(_, _, k_block), tCrS(_, _, 0), tCrA_mma(_, _, k_block)); + auto const& scales = cute::get<1>(partitioned_extra_info)(_, _, k_block); + CUTE_STATIC_ASSERT_V(size(src) == size(scales)); + + if constexpr (is_same_v) { + constexpr int pack = decltype(select_packing::value())::value; + using Converter = cutlass::NumericArrayConverter; + using SrcArray = cutlass::Array; + using DstArray = cutlass::Array; + constexpr int iters = num_elements / pack; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < iters; ++i) { + SrcArray const* pSrcArr = reinterpret_cast(pSrc) + i; + DstArray* pDstArr = reinterpret_cast(pDst) + i; + *pDstArr = Converter::convert(*pSrcArr); + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < pack; ++j) { + (*pDstArr)[j] = (*pDstArr)[j] * scales[i*pack + j]; + } + } + } + else { + constexpr int pack1 = decltype(select_packing::value())::value; + constexpr int pack2 = decltype(select_packing::value())::value; + constexpr int pack = cute::gcd(pack1, pack2); + using Converter1 = cutlass::NumericArrayConverter; + using Converter2 = cutlass::NumericArrayConverter; + using SrcArray = cutlass::Array; + using DstArray = cutlass::Array; + using StageArray = cutlass::Array; + constexpr int iters = num_elements / pack; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < iters; ++i) { + SrcArray const* pSrcArr = reinterpret_cast(pSrc) + i; + DstArray* pDstArr = reinterpret_cast(pDst) + i; + StageArray stageArr; + stageArr = Converter1::convert(*pSrcArr); + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < pack; ++j) { + stageArr[j] = stageArr[j] * scales[i*pack + j]; + } + *pDstArr = Converter2::convert(stageArr); + } + } } else if constexpr (KernelConversionMode == ConversionMode::ConvertAndScaleWithZero) { - auto tCrS = cute::get<1>(partitioned_extra_info); - auto tCrZ = cute::get<3>(partitioned_extra_info); - transform_internal_A(tCrA_load(_, _, k_block), - 
vec_A, - make_fragment_like(tCrA_mma)(_, _, k_block), - tCrS(_, _, 0), - tCrZ(_, _, 0), - make_fragment_like(tCrZ)(_, _, 0), - tCrA_mma(_, _, k_block)); + static_assert(is_same_v, "ElementScale and ElementZero must be the same."); + auto const& scales = cute::get<1>(partitioned_extra_info)(_, _, k_block); + auto const& zeros = cute::get<3>(partitioned_extra_info)(_, _, k_block); + CUTE_STATIC_ASSERT_V(size(src) == size(scales)); + CUTE_STATIC_ASSERT_V(size(src) == size(zeros)); + + if constexpr (is_same_v) { + constexpr int pack = decltype(select_packing::value())::value; + using Converter = cutlass::NumericArrayConverter; + using SrcArray = cutlass::Array; + using DstArray = cutlass::Array; + constexpr int iters = num_elements / pack; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < iters; ++i) { + SrcArray const* pSrcArr = reinterpret_cast(pSrc) + i; + DstArray* pDstArr = reinterpret_cast(pDst) + i; + *pDstArr = Converter::convert(*pSrcArr); + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < pack; ++j) { + (*pDstArr)[j] = (*pDstArr)[j] * scales[i*pack + j] + zeros[i*pack + j]; + } + } + } + else { + constexpr int pack1 = decltype(select_packing::value())::value; + constexpr int pack2 = decltype(select_packing::value())::value; + constexpr int pack = cute::gcd(pack1, pack2); + using Converter1 = cutlass::NumericArrayConverter; + using Converter2 = cutlass::NumericArrayConverter; + using SrcArray = cutlass::Array; + using DstArray = cutlass::Array; + using StageArray = cutlass::Array; + constexpr int iters = num_elements / pack; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < iters; ++i) { + SrcArray const* pSrcArr = reinterpret_cast(pSrc) + i; + DstArray* pDstArr = reinterpret_cast(pDst) + i; + StageArray stageArr; + stageArr = Converter1::convert(*pSrcArr); + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < pack; ++j) { + stageArr[j] = stageArr[j] * scales[i*pack + j] + zeros[i*pack + j]; + } + *pDstArr = Converter2::convert(stageArr); + } + } + return; } else { static_assert(cutlass::detail::dependent_false, "No A data is loaded."); } } - - /// Utilities for transforming the A operand prior to issuing tensorcore math. - template > - CUTLASS_DEVICE void - convert_tensor( - Tensor const& in, - Tensor& out, - cute::Int width = {}) { - - /// This is an element-wise conversion where we expect both tensors to have the same layout. - /// As a result, we can cast as a cutlass array to use the fast numeric converters without - /// worrying about indexing into the layout. - constexpr int N = cosize_v; - - /// The inputs must be backed by registers & be statically sized. 
- static_assert(is_rmem::value, "Input tensor for A conversion must come from registers"); - static_assert(is_rmem::value, "Output tensor for A conversion must come from registers"); - static_assert(is_static_v, "Tensor layout for the conversion must be static"); - static_assert(cosize_v == size(TensorLayout{}), "Cosize and size of the layout must be equal."); - static_assert(N % ConversionVectorWidth == 0, "Conversion vector width must divide cosize of the tensor layout."); - - using SrcType = typename EngineIn::value_type; - using DstType = typename EngineOut::value_type; - - using SrcArray = cutlass::Array; - using DstArray = cutlass::Array; - - constexpr cutlass::FloatRoundStyle RoundStyle = cutlass::FloatRoundStyle::round_to_nearest; - using Converter = cutlass::NumericArrayConverter; - - constexpr int NumIterations = N / ConversionVectorWidth; - - for (int ii = 0; ii < NumIterations; ++ii) { - SrcArray const* src_array_ptr = reinterpret_cast(raw_pointer_cast(in.data())) + ii; - DstArray* dst_array_ptr = reinterpret_cast(raw_pointer_cast(out.data())) + ii; - *dst_array_ptr = Converter::convert(*src_array_ptr); - } - } - - template - CUTLASS_DEVICE void - transform_internal_A( - Tensor&& in, - cute::Int a_vec_width, - Tensor&& out) { - - convert_tensor(in, out, a_vec_width); - } - - template - CUTLASS_DEVICE void - transform_internal_A( - Tensor&& in, - cute::Int a_vec_width, - Tensor&& converted_inputs, - Tensor&& scales, - Tensor&& out) { - - static_assert(cute::is_same_v, - "Type of the engine input buffer must equal the scale buffer"); - - // First, we upcast the inputs to the scale type - convert_tensor(in, converted_inputs, a_vec_width); - - // Apply scales and broadcast across inputs, store in converted_inputs - cute::transform(converted_inputs, scales, converted_inputs, cute::multiplies{}); - - // Finally, we convert the scaled inputs to the mma type. - convert_tensor(converted_inputs, out); - } - - template - CUTLASS_DEVICE void - transform_internal_A( - Tensor&& in, - cute::Int a_vec_width, - Tensor&& converted_inputs, - Tensor&& scales, - Tensor&& zeros, - Tensor&& converted_zeros, - Tensor&& out) { - - static_assert(cute::is_same_v, - "Type of the engine input buffer must equal the scale buffer"); - - static_assert(cute::is_same_v, - "Type of the engine zero buffer must equal the scale buffer"); - - // First, we upcast the inputs to the scale type - convert_tensor(in, converted_inputs, a_vec_width); - convert_tensor(zeros, converted_zeros); - - // Apply scales and broadcast across inputs, store in converted_inputs - cute::transform(converted_inputs, scales, converted_inputs, cute::multiplies{}); - cute::transform(converted_inputs, converted_zeros, converted_inputs, cute::plus{}); - - // Finally, we convert the scaled inputs to the mma type. 
- convert_tensor(converted_inputs, out); - } }; ///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/gemm/collective/sm90_mma_tma_gmma_ss_warpspecialized.hpp b/include/cutlass/gemm/collective/sm90_mma_tma_gmma_ss_warpspecialized.hpp index 24af314d5f..b370dc70b5 100644 --- a/include/cutlass/gemm/collective/sm90_mma_tma_gmma_ss_warpspecialized.hpp +++ b/include/cutlass/gemm/collective/sm90_mma_tma_gmma_ss_warpspecialized.hpp @@ -150,7 +150,7 @@ struct CollectiveMma< struct SharedStorage { - struct TensorStorage : cute::aligned_struct<128> { + struct TensorStorage : cute::aligned_struct<128, _0> { cute::array_aligned> smem_A; cute::array_aligned> smem_B; } tensors; diff --git a/include/cutlass/gemm/collective/sm90_mma_tma_gmma_ss_warpspecialized_fp8.hpp b/include/cutlass/gemm/collective/sm90_mma_tma_gmma_ss_warpspecialized_fp8.hpp index 6c02979996..da5274469f 100644 --- a/include/cutlass/gemm/collective/sm90_mma_tma_gmma_ss_warpspecialized_fp8.hpp +++ b/include/cutlass/gemm/collective/sm90_mma_tma_gmma_ss_warpspecialized_fp8.hpp @@ -144,7 +144,7 @@ struct CollectiveMma< struct SharedStorage { - struct TensorStorage : cute::aligned_struct<128> { + struct TensorStorage : cute::aligned_struct<128, _0> { cute::array_aligned> smem_A; cute::array_aligned> smem_B; } tensors; diff --git a/include/cutlass/gemm/collective/sm90_sparse_mma_tma_gmma_ss_warpspecialized.hpp b/include/cutlass/gemm/collective/sm90_sparse_mma_tma_gmma_ss_warpspecialized.hpp new file mode 100644 index 0000000000..01e83bdf54 --- /dev/null +++ b/include/cutlass/gemm/collective/sm90_sparse_mma_tma_gmma_ss_warpspecialized.hpp @@ -0,0 +1,724 @@ +/*************************************************************************************************** + * Copyright (c) 2024 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ * + **************************************************************************************************/ +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/gemm/collective/builders/sm90_sparse_config.inl" +#include "cutlass/gemm/dispatch_policy.hpp" +#include "cutlass/numeric_types.h" +#include "cutlass/pipeline/pipeline.hpp" +#include "cutlass/trace.h" + +#include "cute/arch/cluster_sm90.hpp" +#include "cute/arch/copy_sm90.hpp" +#include "cute/algorithm/functional.hpp" +#include "cute/atom/mma_atom.hpp" +#include "cute/algorithm/gemm.hpp" +#include "cute/tensor_predicate.hpp" +#include "cute/numeric/arithmetic_tuple.hpp" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass::gemm::collective { +using namespace cute; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +// WarpSpecialized Mainloop +template < + int Stages, + class ClusterShape, + class KernelSchedule, + class TileShape_, + class ElementA_, + class LayoutPairAE_, + class ElementB_, + class StrideB_, + class TiledMma_, + class GmemTiledCopyA_, + class SmemLayoutAtomA_, + class SmemCopyAtomA_, + class TransformA_, + class GmemTiledCopyB_, + class SmemLayoutAtomB_, + class SmemCopyAtomB_, + class TransformB_> +struct CollectiveMma< + MainloopSm90TmaGmmaWarpSpecializedSparse, + TileShape_, + ElementA_, + LayoutPairAE_, + ElementB_, + StrideB_, + TiledMma_, + GmemTiledCopyA_, + SmemLayoutAtomA_, + SmemCopyAtomA_, + TransformA_, + GmemTiledCopyB_, + SmemLayoutAtomB_, + SmemCopyAtomB_, + TransformB_> +{ + // + // Type Aliases + // + using DispatchPolicy = MainloopSm90TmaGmmaWarpSpecializedSparse; + using TileShape = TileShape_; + using TiledMma = TiledMma_; + using ElementA = ElementA_; + using ElementAMma = typename TiledMma::ValTypeA; + using ElementAMmaRaw = typename ElementAMma::raw_type; + using LayoutPairAE = LayoutPairAE_; + using LayoutA = remove_cvref_t(LayoutPairAE{}))>; + using LayoutE = remove_cvref_t(LayoutPairAE{}))>; + using StrideA = decltype(cute::stride(LayoutA{})); + using ElementB = ElementB_; + using ElementBMma = typename TiledMma::ValTypeB; + using StrideB = StrideB_; + using ElementEMma = typename TiledMma::ValTypeE; + using ElementE = typename ElementEMma::raw_type; + using ElementAccumulator = typename TiledMma::ValTypeC; + using GmemTiledCopyA = GmemTiledCopyA_; + using GmemTiledCopyB = GmemTiledCopyB_; + using SmemLayoutAtomA = SmemLayoutAtomA_; + using SmemLayoutAtomB = SmemLayoutAtomB_; + using SmemCopyAtomA = SmemCopyAtomA_; + using SmemCopyAtomB = SmemCopyAtomB_; + using TransformA = TransformA_; + using TransformB = TransformB_; + using ArchTag = typename DispatchPolicy::ArchTag; + using ArrayElementA = ElementA; + using ArrayElementB = ElementB; + + static_assert(is_sparse::value, "ElementAMma is sparse"); + static_assert(!is_sparse::value, "ElementA is not sparse"); + + static constexpr int ElementAMmaSparsity = ElementAMma::sparsity; + static constexpr int ElementEMmaSparsity = ElementEMma::sparsity; + + // LayoutA is nested in the stride due to the sparsity. + static constexpr bool is_A_mn_major = cute::is_same_v(LayoutA{}.stride())), Int>; + static constexpr bool is_B_mn_major = cutlass::gemm::detail::is_major<0,StrideB>(); + + using SparseConfig = cutlass::Sm90GemmSparseConfig(TileShape{}),_128{}))>; + + // The offline permutation for the metadata. 
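+ // The metadata tensor E is expected to be reordered ahead of time (see SparseConfig),
+ // so the kernel can load it with TMA and hand it to the GMMA without reshuffling it
+ // in registers.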
+ using SmemLayoutAtomE_ = typename SparseConfig::TensorEAtom; + using SmemLayoutAtomE = ComposedLayout, + smem_sparse_ptr_flag_bits>, + SmemLayoutAtomE_>; + + // Metadata pathways + using SmemCopyAtomE = AutoVectorizingCopy; + using GmemCopyAtomE = GmemTiledCopyA; + + using CtaShape_MNK = TileShape; + using MainloopPipeline = cutlass::PipelineTmaAsync; + using PipelineState = cutlass::PipelineState; + + using PipelineParams = typename MainloopPipeline::Params; + + static_assert(cute::rank(SmemLayoutAtomA{}) == 2, "SmemLayoutAtom must be rank 2 (M,K)"); + static_assert((size<0>(TileShape{}) % size<0>(SmemLayoutAtomA{})) == 0, "SmemLayoutAtom must evenly divide tile shape."); + static_assert((size<2>(TileShape{}) % size<1>(SmemLayoutAtomA{})) == 0, "SmemLayoutAtom must evenly divide tile shape."); + + static_assert(cute::rank(SmemLayoutAtomB{}) == 2, "SmemLayoutAtom must be rank 2 (N,K)"); + static_assert((size<1>(TileShape{}) % size<0>(SmemLayoutAtomB{})) == 0, "SmemLayoutAtom must evenly divide tile shape."); + static_assert((size<2>(TileShape{}) % size<1>(SmemLayoutAtomB{})) == 0, "SmemLayoutAtom must evenly divide tile shape."); + + // Tile along modes in a way that maximizes the TMA box size. + using SmemLayoutA = decltype(tile_to_shape( + SmemLayoutAtomA{}, + make_shape(shape<0>(TileShape{}), shape<2>(TileShape{}), Int{}), + cute::conditional_t, Step<_1,_2,_3>>{})); + using SmemLayoutE = decltype(tile_to_shape( + SmemLayoutAtomE{}, + make_shape(shape<0>(TileShape{}), shape<2>(TileShape{}), Int{}))); + using SmemLayoutB = decltype(tile_to_shape( + SmemLayoutAtomB{}, + make_shape(shape<1>(TileShape{}), shape<2>(TileShape{}), Int{}), + cute::conditional_t, Step<_1,_2,_3>>{})); + + static_assert(DispatchPolicy::Stages >= 2, "Specialization requires Stages set to value 2 or more."); + static_assert(cute::is_base_of::value && + cute::is_base_of::value, + "MMA atom must source both A and B operand from smem_desc for this mainloop."); + static_assert(cute::is_same_v || cute::is_same_v, + "GmemTiledCopy - invalid SM90 TMA copy atom specified."); + static_assert(cute::is_same_v || cute::is_same_v, + "GmemTiledCopy - invalid SM90 TMA copy atom specified."); + + static_assert(cute::is_void_v, + "SM90 GMMA mainloops cannot have a non-void copy atom for smem sourced instructions."); + static_assert(cute::is_void_v, + "SM90 GMMA mainloops cannot have a non-void copy atom for smem sourced instructions."); + + // TMA converts f32 input to tf32 when copying from GMEM to SMEM + // For all other types, cast to size equivalent uint type to avoid any rounding by TMA. 
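+ // For example, ElementA = float is loaded as tfloat32_t, whereas ElementA = half_t
+ // is loaded as uint16_t so the copy is a pure bit transfer.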
+ using TmaInternalElementA = cute::sparse_elem, + cutlass::tfloat32_t, + uint_bit_t>>>; + using TmaInternalElementB = cute::conditional_t, + tfloat32_t, + uint_bit_t>>; + + struct SharedStorage + { + struct TensorStorage { + alignas(128) cute::ArrayEngine> smem_A; + alignas(128) cute::ArrayEngine> smem_B; + alignas(128) cute::ArrayEngine> smem_E; + } tensors; + + using PipelineStorage = typename MainloopPipeline::SharedStorage; + PipelineStorage pipeline; + }; + using TensorStorage = typename SharedStorage::TensorStorage; + using PipelineStorage = typename SharedStorage::PipelineStorage; + + static constexpr int K_PIPE_MAX = DispatchPolicy::Stages; + static constexpr int K_PIPE_MMAS = 0; + + static constexpr uint32_t TmaTransactionBytes = + cutlass::bits_to_bytes(cosize(take<0,2>(SmemLayoutA{})) * cute::sizeof_bits_v) + + cutlass::bits_to_bytes(cosize(take<0,2>(SmemLayoutE{})) * cute::sizeof_bits_v) + + cutlass::bits_to_bytes(cosize(take<0,2>(SmemLayoutB{})) * cute::sizeof_bits_v); + + // Host side kernel arguments + struct Arguments { + ElementA const* ptr_A{}; + LayoutA layout_a{}; + ElementB const* ptr_B{}; + StrideB dB{}; + ElementE const* ptr_E{}; + LayoutE layout_e{}; + }; + + // Device side kernel params + struct Params { + + using TMA_A = decltype(make_tma_copy( + GmemTiledCopyA{}, + make_tensor(recast_ptr(nullptr), LayoutA{}), + SmemLayoutA{}(_,_,cute::Int<0>{}), + make_shape(shape<0>(TileShape{}), shape<2>(TileShape{})), + size<1>(ClusterShape{}))); // mcast along N mode for this M load, if any + + using TMA_E = decltype(make_tma_copy( // use uint64_t to get the largest loading box. + GmemCopyAtomE{}, + make_tensor(recast_ptr>(nullptr), LayoutE{}), + SmemLayoutE{}(_,_,cute::Int<0>{}), + make_shape(shape<0>(TileShape{}), shape<2>(TileShape{})), + size<1>(ClusterShape{}))); // mcast along N mode for this M load, if any + + using TMA_B = decltype(make_tma_copy( + GmemTiledCopyB{}, + make_tensor(static_cast(nullptr), repeat_like(StrideB{}, int32_t(0)), StrideB{}), + SmemLayoutB{}(_,_,cute::Int<0>{}), + make_shape(shape<1>(TileShape{}), shape<2>(TileShape{})), + size<0>(ClusterShape{}))); // mcast along M mode for this N load, if any + + TMA_A tma_load_a; + TMA_E tma_load_e; + TMA_B tma_load_b; + LayoutA layout_a; + LayoutE layout_e; + uint32_t tma_transaction_bytes = TmaTransactionBytes; + }; + + // + // Methods + // + + template + static constexpr Params + to_underlying_arguments(ProblemShape const& problem_shape, Arguments const& args, void* workspace) { + (void) workspace; + + // Optionally append 1s until problem shape is rank-4 (MNKL), in case it is only rank-3 (MNK) + auto problem_shape_MNKL = append<4>(problem_shape, 1); + auto [M,N,K,L] = problem_shape_MNKL; + + auto ptr_A = recast_ptr(args.ptr_A); + auto ptr_B = recast_ptr(args.ptr_B); + auto ptr_E = recast_ptr>(args.ptr_E); + + Tensor tensor_a = make_tensor(ptr_A, args.layout_a); + Tensor tensor_b = make_tensor(ptr_B, make_layout(make_shape(N,K,L), args.dB)); + Tensor tensor_e = make_tensor(ptr_E, args.layout_e); + + typename Params::TMA_A tma_load_a = make_tma_copy( + GmemTiledCopyA{}, + tensor_a, + SmemLayoutA{}(_,_,cute::Int<0>{}), + make_shape(shape<0>(TileShape{}), shape<2>(TileShape{})), + size<1>(ClusterShape{})); // mcast along N mode for this M load, if any + + typename Params::TMA_E tma_load_e = make_tma_copy( // use uint64_t to get the largest loading box. 
+ GmemCopyAtomE{}, + tensor_e, + SmemLayoutE{}(_,_,cute::Int<0>{}), + make_shape(shape<0>(TileShape{}), shape<2>(TileShape{})), + size<1>(ClusterShape{})); // mcast along N mode for this M load, if any + + typename Params::TMA_B tma_load_b = make_tma_copy( + GmemTiledCopyB{}, + tensor_b, + SmemLayoutB{}(_,_,cute::Int<0>{}), + make_shape(shape<1>(TileShape{}), shape<2>(TileShape{})), + size<0>(ClusterShape{})); // mcast along M mode for this N load, if any + + return { + tma_load_a, + tma_load_e, + tma_load_b, + args.layout_a, + args.layout_e + }; + } + + template + CUTLASS_HOST_DEVICE static bool + can_implement( + ProblemShape const& problem_shape, + [[maybe_unused]] Arguments const& args) { + constexpr int tma_alignment_bits = 128; + constexpr int min_tma_aligned_elements_A = tma_alignment_bits / cutlass::sizeof_bits::value; + constexpr int min_tma_aligned_elements_B = tma_alignment_bits / cutlass::sizeof_bits::value; + auto problem_shape_MNKL = append<4>(problem_shape, 1); + auto [M,N,K,L] = problem_shape_MNKL; + + bool size_check = true; + // Check Alignment A + if constexpr (is_A_mn_major) { + size_check = size_check && cutlass::detail::check_alignment(cute::make_shape(M,K/2,L), cute::make_stride(_1{}, M, M*K/2)); + } + else { // If A is K-major + size_check = size_check && cutlass::detail::check_alignment(cute::make_shape(M,K/2,L), cute::make_stride(K/2, _1{}, M*K/2)); + } + size_check = size_check && cutlass::detail::check_alignment(cute::make_shape(N,K,L), StrideB{}); + + if (!size_check) { + CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Problem Size doesn't meet the minimum alignment requirements for TMA.\n"); + } + + // Check if layout_a and layout_e is filled correctly + auto layout_a_ref = SparseConfig::fill_layoutA(problem_shape_MNKL); + auto layout_e_ref = SparseConfig::fill_layoutE(problem_shape_MNKL); + bool layout_check = true; + layout_check = layout_check && (layout_a_ref == args.layout_a); + layout_check = layout_check && (layout_e_ref == args.layout_e); + + if (!layout_check) { + CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Layout_a/e mismatch.\n"); + } + + return size_check && layout_check; + } + + /// Issue Tma Descriptor Prefetch -- ideally from a single thread for best performance + CUTLASS_DEVICE + static void prefetch_tma_descriptors(Params const& mainloop_params) { + cute::prefetch_tma_descriptor(mainloop_params.tma_load_a.get_tma_descriptor()); + cute::prefetch_tma_descriptor(mainloop_params.tma_load_e.get_tma_descriptor()); + cute::prefetch_tma_descriptor(mainloop_params.tma_load_b.get_tma_descriptor()); + } + + /// Set up the data needed by this collective for load and mma. + /// Returns a tuple of tensors. The collective and the kernel layer have the contract + /// Returned tuple must contain at least two elements, with the first two elements being: + /// gA_mkl - The tma tensor, A after a local tile so it has shape (BLK_M,BLK_K,m,k,l) + /// gB_nkl - The tma tensor, B after a local tile so it has shape (BLK_N,BLK_K,n,k,l) + /// The rest of the tensors can be specified as needed by this collective. 
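+ /// For this sparse mainloop the additional (third) element is gE_mkl, the TMA
+ /// tensor for the metadata E, tiled like A with shape (BLK_M,BLK_K,m,k,l).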
+ template + CUTLASS_DEVICE auto + load_init(ProblemShape_MNKL const& problem_shape_MNKL, Params const& mainloop_params) const { + using X = Underscore; + // Separate out problem shape for convenience + auto [M,N,K,L] = problem_shape_MNKL; + + // TMA requires special handling of strides to deal with coord codomain mapping + // Represent the full tensors -- get these from TMA + Tensor mA_mkl = mainloop_params.tma_load_a.get_tma_tensor(mainloop_params.layout_a.shape()); // (m,k,l) + Tensor mE_mkl = mainloop_params.tma_load_e.get_tma_tensor(mainloop_params.layout_e.shape()); // (m,k,l) + Tensor mB_nkl = mainloop_params.tma_load_b.get_tma_tensor(make_shape(N,K,L)); // (n,k,l) + + // Make tiled views, defer the slice + Tensor gA_mkl = local_tile(mA_mkl, TileShape{}, make_coord(_,_,_), Step<_1, X,_1>{}); // (BLK_M,BLK_K,m,k,l) + Tensor gE_mkl = local_tile(mE_mkl, TileShape{}, make_coord(_,_,_), Step<_1, X,_1>{}); // (BLK_M,BLK_K,m,k,l) + Tensor gB_nkl = local_tile(mB_nkl, TileShape{}, make_coord(_,_,_), Step< X,_1,_1>{}); // (BLK_N,BLK_K,n,k,l) + + return cute::make_tuple(gA_mkl, gB_nkl, gE_mkl); + } + + /// Perform a collective-scoped matrix multiply-accumulate + /// Producer Perspective + template < + class TensorA, class TensorB, class TensorE, + class KTileIterator, class BlockCoord + > + CUTLASS_DEVICE void + load( + Params const& mainloop_params, + MainloopPipeline pipeline, + PipelineState smem_pipe_write, + cute::tuple const& load_inputs, + BlockCoord const& blk_coord, + KTileIterator k_tile_iter, int k_tile_count, + int thread_idx, + uint32_t block_rank_in_cluster, + TensorStorage& shared_tensors) { + int lane_predicate = cute::elect_one_sync(); + + if (lane_predicate) { + Tensor sA = make_tensor(make_smem_ptr(shared_tensors.smem_A.begin()), SmemLayoutA{}); // (BLK_M,BLK_K,PIPE) + Tensor sE = make_tensor(make_smem_ptr(shared_tensors.smem_E.begin()), SmemLayoutE{}); // (BLK_M,BLK_K,PIPE) + Tensor sB = make_tensor(make_smem_ptr(shared_tensors.smem_B.begin()), SmemLayoutB{}); // (BLK_N,BLK_K,PIPE) + + auto [gA_mkl, gB_nkl, gE_mkl] = load_inputs; + + // Define the CTA-in-cluster Layout and Coord + Layout cta_layout_mnk = make_layout(ClusterShape{}); + auto cta_coord_mnk = cta_layout_mnk.get_flat_coord(block_rank_in_cluster); + + // TMA Multicast Masks + uint16_t mcast_mask_a = create_tma_multicast_mask<1>(cta_layout_mnk, cta_coord_mnk); + uint16_t mcast_mask_e = create_tma_multicast_mask<1>(cta_layout_mnk, cta_coord_mnk); + uint16_t mcast_mask_b = create_tma_multicast_mask<0>(cta_layout_mnk, cta_coord_mnk); + + auto block_tma_a = mainloop_params.tma_load_a.get_slice(get<1>(cta_coord_mnk)); + auto block_tma_e = mainloop_params.tma_load_e.get_slice(get<1>(cta_coord_mnk)); + auto block_tma_b = mainloop_params.tma_load_b.get_slice(get<0>(cta_coord_mnk)); + + // Partition the inputs based on the current block coordinates. 
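+ // Slicing away the m/n and l modes leaves only the k-tile mode, e.g.
+ // gA = gA_mkl(_,_,m_coord,_,l_coord) has shape (BLK_M,BLK_K,k).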
+ auto [m_coord, n_coord, k_coord, l_coord] = blk_coord; + Tensor gA = gA_mkl(_,_,m_coord,_,l_coord); // (BLK_M,BLK_K,k) + Tensor gE = gE_mkl(_,_,m_coord,_,l_coord); // (BLK_M,BLK_K,k) + Tensor gB = gB_nkl(_,_,n_coord,_,l_coord); // (BLK_N,BLK_K,k) + + // Applies the mapping from block_tma_a + Tensor tAgA = block_tma_a.partition_S(gA); // (TMA,TMA_M,TMA_K,k) + Tensor tAsA = block_tma_a.partition_D(sA); // (TMA,TMA_M,TMA_K,PIPE) + + Tensor tEgE = block_tma_e.partition_S(gE); // (TMA,TMA_M,TMA_K,k) + Tensor tEsE = block_tma_e.partition_D(sE); // (TMA,TMA_M,TMA_K,PIPE) + + Tensor tBgB = block_tma_b.partition_S(gB); // (TMA,TMA_N,TMA_K,k) + Tensor tBsB = block_tma_b.partition_D(sB); // (TMA,TMA_N,TMA_K,PIPE) + + // Mainloop + CUTLASS_PRAGMA_NO_UNROLL + for ( ; k_tile_count > 0; --k_tile_count) + { + // LOCK smem_pipe_write for _writing_ + pipeline.producer_acquire(smem_pipe_write); + + // + // Copy gmem to smem for *k_tile_iter + // + + using BarrierType = typename MainloopPipeline::ProducerBarrierType; + BarrierType* tma_barrier = pipeline.producer_get_barrier(smem_pipe_write); + + int write_stage = smem_pipe_write.index(); + copy(mainloop_params.tma_load_a.with(*tma_barrier, mcast_mask_a), tAgA(_,_,_,*k_tile_iter), tAsA(_,_,_,write_stage)); + copy(mainloop_params.tma_load_e.with(*tma_barrier, mcast_mask_e), tEgE(_,_,_,*k_tile_iter), tEsE(_,_,_,write_stage)); + copy(mainloop_params.tma_load_b.with(*tma_barrier, mcast_mask_b), tBgB(_,_,_,*k_tile_iter), tBsB(_,_,_,write_stage)); + ++k_tile_iter; + + // Advance smem_pipe_write + ++smem_pipe_write; + } + } + } + + /// Perform a Producer Epilogue to prevent early exit of blocks in a Cluster + CUTLASS_DEVICE void + load_tail(MainloopPipeline pipeline, PipelineState smem_pipe_write) { + int lane_predicate = cute::elect_one_sync(); + + // Issue the epilogue waits + if (lane_predicate) { + /* This helps avoid early exit of blocks in Cluster + * Waits for all stages to either be released (all + * Consumer UNLOCKs), or if the stage was never used + * then would just be acquired since the phase was + * still inverted from make_producer_start_state + */ + pipeline.producer_tail(smem_pipe_write); + } + } + + /// Perform a collective-scoped matrix multiply-accumulate + /// Consumer Perspective + template < + class FrgTensorC + > + CUTLASS_DEVICE void + mma(MainloopPipeline pipeline, + PipelineState smem_pipe_read, + FrgTensorC& accum, + int k_tile_count, + int thread_idx, + TensorStorage& shared_tensors, + Params const& mainloop_params) { + static_assert(is_rmem::value, "C tensor must be rmem resident."); + static_assert(cute::rank(SmemLayoutA{}) == 3, "Smem layout must be rank 3."); + static_assert(cute::rank(SmemLayoutE{}) == 3, "Smem layout must be rank 3."); + static_assert(cute::rank(SmemLayoutB{}) == 3, "Smem layout must be rank 3."); + + Tensor sA = make_tensor(make_smem_ptr(shared_tensors.smem_A.begin()), SmemLayoutA{}); // (BLK_M,BLK_K,PIPE) + Tensor sB = make_tensor(make_smem_ptr(shared_tensors.smem_B.begin()), SmemLayoutB{}); // (BLK_N,BLK_K,PIPE) + + Tensor sE_ = make_tensor(make_smem_ptr(shared_tensors.smem_E.begin()), SmemLayoutE{}); // (BLK_M,BLK_K,PIPE) + Tensor sE = as_position_independent_swizzle_tensor(sE_); + + // + // Define C accumulators and A/B partitioning + // + + TiledMma tiled_mma; + auto thread_mma = tiled_mma.get_thread_slice(thread_idx); + + Tensor tCsA = thread_mma.partition_A(sA); // (MMA,MMA_M,MMA_K,PIPE) + Tensor tCsB = thread_mma.partition_B(sB); // (MMA,MMA_N,MMA_K,PIPE) + + // Allocate "fragments/descriptors" + Tensor 
tCrA = thread_mma.make_fragment_A(tCsA); // (MMA,MMA_M,MMA_K,PIPE) + Tensor tCrB = thread_mma.make_fragment_B(tCsB); // (MMA,MMA_N,MMA_K,PIPE) + + CUTE_STATIC_ASSERT_V(size<1>(tCsA) == size<1>(accum)); // M + CUTE_STATIC_ASSERT_V(size<1>(tCsB) == size<2>(accum)); // N + CUTE_STATIC_ASSERT_V(size<2>(tCsA) == size<2>(tCsB)); // K + CUTE_STATIC_ASSERT_V(size<3>(tCsA) == size<3>(tCsB)); // PIPE + CUTE_STATIC_ASSERT_V(Int{} == size<2>(sA)); // PIPE + CUTE_STATIC_ASSERT_V(Int{} == size<2>(sB)); // PIPE + + auto copy_atom_E = Copy_Atom{}; + + Tensor tCsE = partition_E(thread_mma, sE(_,_,Int<0>{})); // (MMA,MMA_M,MMA_K) + Tensor tCrE = make_fragment_like(tCsE); // (MMA,MMA_M,MMA_K) + + auto smem_tiled_copy_E = make_tiled_copy_E(copy_atom_E, tiled_mma); + auto smem_thr_copy_E = smem_tiled_copy_E.get_thread_slice(thread_idx); + + Tensor tEsE = smem_thr_copy_E.partition_S(sE); // (ECPY,ECPY_M,ECPY_K) + Tensor tErE = smem_thr_copy_E.retile_D(tCrE); // (ECPY,ECPY_M,ECPY_K) + + // + // PIPELINED MAIN LOOP + // + static_assert((0 <= K_PIPE_MMAS) && (K_PIPE_MMAS < K_PIPE_MAX), + "ERROR : Incorrect number of MMAs in flight"); + + // We release buffers to producer warps(dma load) with some mmas in flight + PipelineState smem_pipe_release = smem_pipe_read; + + // Prologue GMMAs + int prologue_mma_count = min(K_PIPE_MMAS, k_tile_count); + + tiled_mma.accumulate_ = GMMA::ScaleOut::Zero; + + warpgroup_fence_operand(accum); + CUTLASS_PRAGMA_UNROLL + for (int k_tile_prologue = prologue_mma_count; k_tile_prologue > 0; --k_tile_prologue) + { + // WAIT on smem_pipe_read until its data are available (phase bit flips from rdPhaseBit value) + auto barrier_token = pipeline.consumer_try_wait(smem_pipe_read); + pipeline.consumer_wait(smem_pipe_read, barrier_token); + int read_stage = smem_pipe_read.index(); + + // Load metadata smem->rmem for one stage + copy(smem_tiled_copy_E, tEsE(_,_,_,read_stage), tErE); + + warpgroup_arrive(); + // Unroll the K mode manually to set scale D to 1 + CUTLASS_PRAGMA_UNROLL + for (int k_block = 0; k_block < size<2>(tCrA); ++k_block) { + cute::gemm(tiled_mma, make_zip_tensor(tCrA(_,_,k_block,read_stage), tErE(_,_,k_block)), tCrB(_,_,k_block,read_stage), accum); + tiled_mma.accumulate_ = GMMA::ScaleOut::One; + } + + warpgroup_commit_batch(); + + ++smem_pipe_read; + } + + warpgroup_fence_operand(accum); + // Mainloop GMMAs + k_tile_count -= prologue_mma_count; + + CUTLASS_PRAGMA_NO_UNROLL + for ( ; k_tile_count > 0; --k_tile_count) + { + // WAIT on smem_pipe_read until its data are available (phase bit flips from rdPhaseBit value) + auto barrier_token = pipeline.consumer_try_wait(smem_pipe_read); + pipeline.consumer_wait(smem_pipe_read, barrier_token); + int read_stage = smem_pipe_read.index(); + + // Load metadata smem->rmem for one stage + copy(smem_tiled_copy_E, tEsE(_,_,_,read_stage), tErE); + + warpgroup_fence_operand(accum); + warpgroup_arrive(); + // Unroll the K mode manually to set scale D to 1 + CUTLASS_PRAGMA_UNROLL + for (int k_block = 0; k_block < size<2>(tCrA); ++k_block) { + cute::gemm(tiled_mma, make_zip_tensor(tCrA(_,_,k_block,read_stage), tErE(_,_,k_block)), tCrB(_,_,k_block,read_stage), accum); + tiled_mma.accumulate_ = GMMA::ScaleOut::One; + } + warpgroup_commit_batch(); + + /// Wait on the GMMA barrier for K_PIPE_MMAS (or fewer) outstanding to ensure smem_pipe_write is consumed + warpgroup_wait(); + warpgroup_fence_operand(accum); + + // UNLOCK smem_pipe_release, done _computing_ on it + pipeline.consumer_release(smem_pipe_release); + + // Advance smem_pipe_read and 
smem_pipe_release + ++smem_pipe_read; + ++smem_pipe_release; + } + + warpgroup_fence_operand(accum); + } + + /// Perform a Consumer Epilogue to release all buffers + CUTLASS_DEVICE void + mma_tail(MainloopPipeline pipeline, PipelineState smem_pipe_release, int k_tile_count) { + // Prologue GMMAs + int prologue_mma_count = min(K_PIPE_MMAS, k_tile_count); + k_tile_count -= prologue_mma_count; + + smem_pipe_release.advance(k_tile_count); + + // Wait on all GMMAs to complete + warpgroup_wait<0>(); + + for (int count = 0; count < prologue_mma_count; ++count) { + pipeline.consumer_release(smem_pipe_release); // UNLOCK smem_pipe_release, done _computing_ on it + ++smem_pipe_release; + } + } + +private: + + template + CUTE_HOST_DEVICE static constexpr + auto + thrfrg_E(TiledMMA const& mma, ETensor&& etensor) + { + using TiledMma = TiledMMA; + + CUTE_STATIC_ASSERT_V(rank(etensor) >= Int<2>{}); + + // Reorder the tensor for the TiledAtom + auto t_tile = make_tile(get<0>(PermutationMNK{}), + get<2>(PermutationMNK{})); + auto t_tensor = logical_divide(etensor, t_tile); // (PermM,PermK) + + // Tile the tensor for the Atom + auto e_tile = make_tile(make_layout(size<0>(typename TiledMma::AtomShape_MNK{})), + make_layout(size<2>(typename TiledMma::AtomShape_MNK{}))); + auto e_tensor = zipped_divide(t_tensor, e_tile); // ((AtomM,AtomK),(RestM,RestK)) + + // Transform the Atom mode from (M,K) to (Thr,Val) + using AtomLayoutE_TV = typename TiledMma::Atom::Traits::ELayout; + auto tv_tensor = e_tensor.compose(AtomLayoutE_TV{},_); // ((ThrV,FrgV),(RestM,RestK)) + + // Tile the tensor for the Thread + auto thr_tile = make_tile(_, + make_tile(make_layout(size<1>(mma.thr_layout_vmnk_)), + make_layout(size<3>(mma.thr_layout_vmnk_)))); + auto thr_tensor = zipped_divide(tv_tensor, thr_tile); // ((ThrV,(ThrM,ThrK)),(FrgV,(RestM,RestK))) + + return thr_tensor; + } + + template + CUTE_HOST_DEVICE static constexpr + auto + get_layoutE_TV(TiledMMA const& mma) + { + // (M,K) -> (M,K) + auto ref_E = make_layout(make_shape(tile_size<0>(mma), tile_size<2>(mma))); + // (ethrid,val) -> (M,K) + auto layoutE_TV = thrfrg_E(mma, ref_E); + + // (ThrV,(ThrM,ThrK)) -> (ThrV,(ThrM,ThrN,ThrK)) + auto etile = make_tile(_, + make_tile(make_layout(make_shape (size<1>(mma.thr_layout_vmnk_), size<2>(mma.thr_layout_vmnk_)), + make_stride( Int<1>{} , Int<0>{} )), + _)); + + // thr_idx -> (ThrV,ThrM,ThrN,ThrK) + auto thridx_2_thrid = right_inverse(mma.thr_layout_vmnk_); + + // (thr_idx,val) -> (M,K) + return layoutE_TV.compose(etile, _).compose(thridx_2_thrid, _); + } + + template + CUTE_HOST_DEVICE static constexpr + auto + partition_E(ThrMMA const& thr_mma, ETensor&& etensor) + { + auto thr_tensor = make_tensor(static_cast(etensor).data(), thrfrg_E(thr_mma, etensor.layout())); + + auto thr_vmk = make_coord(get<0>(thr_mma.thr_vmnk_), make_coord(get<1>(thr_mma.thr_vmnk_), get<3>(thr_mma.thr_vmnk_))); + return thr_tensor(thr_vmk, make_coord(_, repeat(thr_tensor)>(_))); + } + + template + CUTE_HOST_DEVICE static constexpr + auto + make_tiled_copy_E(Copy_Atom const& copy_atom, + TiledMMA const& mma) + { + return make_tiled_copy_impl(copy_atom, get_layoutE_TV(mma), make_shape(tile_size<0>(mma),tile_size<2>(mma))); + } + +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass::gemm::collective + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/gemm/device/base_grouped.h 
b/include/cutlass/gemm/device/base_grouped.h index 51b9d3dc10..eec61981f8 100644 --- a/include/cutlass/gemm/device/base_grouped.h +++ b/include/cutlass/gemm/device/base_grouped.h @@ -432,6 +432,7 @@ class BaseGrouped { // // Launch + cutlass::arch::synclog_setup(); cutlass::Kernel<<>>(params_); // diff --git a/include/cutlass/gemm/device/default_gemm_configuration.h b/include/cutlass/gemm/device/default_gemm_configuration.h index c9e7cc76d1..e7ed2da940 100644 --- a/include/cutlass/gemm/device/default_gemm_configuration.h +++ b/include/cutlass/gemm/device/default_gemm_configuration.h @@ -764,50 +764,19 @@ struct DefaultGemmConfiguration< //////////////////////////////////////////////////////////////////////////////// -/// Base configuration for all {fe4m3, fe5m2} x {fe4m3, fe5m2} combinations on SM89 template < - typename ElementA, - typename ElementB, - typename ElementC, - typename ElementAccumulator> -struct DefaultGemmConfigurationSm89F8 { - static_assert((platform::is_same::value || - platform::is_same::value), - "ElementA must be of type float_e4m3_t or float_e5m2_t"); - static_assert((platform::is_same::value || - platform::is_same::value), - "ElementB must be of type float_e4m3_t or float_e5m2_t"); - - static int const kAlignmentA = 128 / sizeof_bits::value; - static int const kAlignmentB = 128 / sizeof_bits::value; - - using ThreadblockShape = GemmShape<128, 256, 64>; - using WarpShape = GemmShape<64, 64, 64>; - using InstructionShape = GemmShape<16, 8, 32>; - static int const kStages = 3; - - using EpilogueOutputOp = epilogue::thread::LinearCombination< - ElementC, 128 / sizeof_bits::value, ElementAccumulator, - ElementAccumulator>; - - using Operator = arch::OpMultiplyAdd; -}; - -//////////////////////////////////////////////////////////////////////////////// - -template < typename ElementC> struct DefaultGemmConfiguration< - arch::OpClassTensorOp, - arch::Sm80, - int4b_t, - int8_t, - ElementC, + arch::OpClassTensorOp, + arch::Sm80, + int4b_t, + int8_t, + ElementC, int32_t> { - + static int const kAlignmentA = 128 / sizeof_bits::value; static int const kAlignmentB = 128 / sizeof_bits::value; - + using ThreadblockShape = GemmShape<128, 256, 64>; using WarpShape = GemmShape<64, 64, 64>; using InstructionShape = GemmShape<16, 8, 32>; @@ -821,19 +790,19 @@ struct DefaultGemmConfiguration< //////////////////////////////////////////////////////////////////////////////// -template < +template < typename ElementC> struct DefaultGemmConfiguration< - arch::OpClassTensorOp, - arch::Sm80, - int8_t, - int4b_t, - ElementC, + arch::OpClassTensorOp, + arch::Sm80, + int8_t, + int4b_t, + ElementC, int32_t> { - + static int const kAlignmentA = 128 / sizeof_bits::value; static int const kAlignmentB = 128 / sizeof_bits::value; - + using ThreadblockShape = GemmShape<128, 256, 64>; using WarpShape = GemmShape<64, 64, 64>; using InstructionShape = GemmShape<16, 8, 32>; @@ -847,6 +816,35 @@ struct DefaultGemmConfiguration< //////////////////////////////////////////////////////////////////////////////// +/// Base configuration for all {fe4m3, fe5m2} x {fe4m3, fe5m2} combinations on SM89 +template < + typename ElementA, + typename ElementB, + typename ElementC, + typename ElementAccumulator> +struct DefaultGemmConfigurationSm89F8 { + static_assert((platform::is_same::value || + platform::is_same::value), + "ElementA must be of type float_e4m3_t or float_e5m2_t"); + static_assert((platform::is_same::value || + platform::is_same::value), + "ElementB must be of type float_e4m3_t or float_e5m2_t"); + + static 
int const kAlignmentA = 128 / sizeof_bits::value; + static int const kAlignmentB = 128 / sizeof_bits::value; + + using ThreadblockShape = GemmShape<128, 256, 64>; + using WarpShape = GemmShape<64, 64, 64>; + using InstructionShape = GemmShape<16, 8, 32>; + static int const kStages = 3; + + using EpilogueOutputOp = epilogue::thread::LinearCombination< + ElementC, 128 / sizeof_bits::value, ElementAccumulator, + ElementAccumulator>; + + using Operator = arch::OpMultiplyAdd; +}; + /// Partial specialization for SM89 fe4m3 x fe4m3 template struct DefaultGemmConfiguration< diff --git a/include/cutlass/gemm/device/ell_gemm.h b/include/cutlass/gemm/device/ell_gemm.h index f5b65cea29..54ddab4007 100644 --- a/include/cutlass/gemm/device/ell_gemm.h +++ b/include/cutlass/gemm/device/ell_gemm.h @@ -517,6 +517,7 @@ class EllGemm { } } + cutlass::arch::synclog_setup(); cutlass::Kernel<<>>(params_); result = cudaGetLastError(); diff --git a/include/cutlass/gemm/device/gemm.h b/include/cutlass/gemm/device/gemm.h index f0226354de..c6f488b146 100644 --- a/include/cutlass/gemm/device/gemm.h +++ b/include/cutlass/gemm/device/gemm.h @@ -491,6 +491,7 @@ class Gemm { } } + cutlass::arch::synclog_setup(); cutlass::Kernel<<>>(params_); result = cudaGetLastError(); diff --git a/include/cutlass/gemm/device/gemm_array.h b/include/cutlass/gemm/device/gemm_array.h index 6bbd90c1cd..1ae2db467f 100644 --- a/include/cutlass/gemm/device/gemm_array.h +++ b/include/cutlass/gemm/device/gemm_array.h @@ -446,6 +446,7 @@ class GemmArray { } } + cutlass::arch::synclog_setup(); cutlass::Kernel<<>>(params_); result = cudaGetLastError(); diff --git a/include/cutlass/gemm/device/gemm_batched.h b/include/cutlass/gemm/device/gemm_batched.h index 3be34c808d..5981457c73 100644 --- a/include/cutlass/gemm/device/gemm_batched.h +++ b/include/cutlass/gemm/device/gemm_batched.h @@ -424,6 +424,7 @@ class GemmBatched { } } + cutlass::arch::synclog_setup(); cutlass::Kernel<<>>(params_); result = cudaGetLastError(); diff --git a/include/cutlass/gemm/device/gemm_complex.h b/include/cutlass/gemm/device/gemm_complex.h index 36f57d6469..e36c69cefb 100644 --- a/include/cutlass/gemm/device/gemm_complex.h +++ b/include/cutlass/gemm/device/gemm_complex.h @@ -445,6 +445,7 @@ class GemmComplex { } } + cutlass::arch::synclog_setup(); cutlass::Kernel<<>>(params_); result = cudaGetLastError(); diff --git a/include/cutlass/gemm/device/gemm_sparse.h b/include/cutlass/gemm/device/gemm_sparse.h index 1b1d27bda5..ac453c63b5 100644 --- a/include/cutlass/gemm/device/gemm_sparse.h +++ b/include/cutlass/gemm/device/gemm_sparse.h @@ -479,6 +479,7 @@ class SparseGemm { int smem_size = int(sizeof(typename GemmKernel::SharedStorage)); + cutlass::arch::synclog_setup(); cutlass::Kernel<<>>(params_); cudaError_t result = cudaGetLastError(); diff --git a/include/cutlass/gemm/device/gemm_sparse_with_absmax.h b/include/cutlass/gemm/device/gemm_sparse_with_absmax.h index e6db107604..e599217a13 100644 --- a/include/cutlass/gemm/device/gemm_sparse_with_absmax.h +++ b/include/cutlass/gemm/device/gemm_sparse_with_absmax.h @@ -324,6 +324,7 @@ class SparseGemmWithAbsmax { int smem_size = int(sizeof(typename GemmKernel::SharedStorage)); + cutlass::arch::synclog_setup(); cutlass::Kernel<<>>(params_); cudaError_t result = cudaGetLastError(); diff --git a/include/cutlass/gemm/device/gemm_splitk_parallel.h b/include/cutlass/gemm/device/gemm_splitk_parallel.h index 2c9408df0e..f78c5a2169 100644 --- a/include/cutlass/gemm/device/gemm_splitk_parallel.h +++ 
b/include/cutlass/gemm/device/gemm_splitk_parallel.h @@ -357,6 +357,7 @@ class GemmSplitKParallel { } } + cutlass::arch::synclog_setup(); Kernel<<>>(gemm_params_); result = cudaGetLastError(); diff --git a/include/cutlass/gemm/device/gemm_universal_adapter.h b/include/cutlass/gemm/device/gemm_universal_adapter.h index 40094dcb10..73564d3c65 100644 --- a/include/cutlass/gemm/device/gemm_universal_adapter.h +++ b/include/cutlass/gemm/device/gemm_universal_adapter.h @@ -44,6 +44,7 @@ #include "cutlass/detail/mma.hpp" #include "cutlass/cuda_host_adapter.hpp" +#include "cutlass/kernel_launch.h" #if !defined(__CUDACC_RTC__) #include "cutlass/cluster_launch.hpp" #include "cutlass/trace.h" @@ -211,9 +212,10 @@ class GemmUniversalAdapter< workspace_bytes += sizeof(int) * size_t(cute::size<0>(TileShape{})) * size_t(cute::size<1>(TileShape{})); } + workspace_bytes += GemmKernel::get_workspace_size(args); + CUTLASS_TRACE_HOST(" workspace_bytes: " << workspace_bytes); - workspace_bytes += GemmKernel::get_workspace_size(args); return workspace_bytes; } @@ -350,8 +352,12 @@ class GemmUniversalAdapter< Status launch_result{ Status::kSuccess }; // Use extended launch API only for mainloops that use it if constexpr (GemmKernel::ArchTag::kMinComputeCapability >= 90) { - constexpr bool is_static_1x1x1 = cute::is_static_v and - cute::size(typename GemmKernel::DispatchPolicy::ClusterShape{}) == 1; +#if (CUTLASS_DEBUG_TRACE_LEVEL > 1) + CUTLASS_TRACE_HOST("GemmUniversal::run: Use extended launch API"); +#endif + [[maybe_unused]] constexpr bool is_static_1x1x1 = + cute::is_static_v and + cute::size(typename GemmKernel::DispatchPolicy::ClusterShape{}) == 1; dim3 cluster(cute::size<0>(typename GemmKernel::DispatchPolicy::ClusterShape{}), cute::size<1>(typename GemmKernel::DispatchPolicy::ClusterShape{}), cute::size<2>(typename GemmKernel::DispatchPolicy::ClusterShape{})); @@ -363,12 +369,14 @@ class GemmUniversalAdapter< // CUTLASS_ASSERT(cuda_adapter); if (cuda_adapter) { - if (launch_with_pdl) { CUTLASS_TRACE_HOST( "GemmUniversal::run() does not support launching with PDL and a custom cuda adapter."); return Status::kErrorInternal; } +#if (CUTLASS_DEBUG_TRACE_LEVEL > 1) + CUTLASS_TRACE_HOST("GemmUniversal::run: Launching kernel with CUDA host adapter"); +#endif launch_result = cuda_adapter->launch(grid, cluster, block, @@ -378,6 +386,7 @@ class GemmUniversalAdapter< 0); } else { + CUTLASS_TRACE_HOST("GemmUniversal::run: kEnableCudaHostAdapter is true, but CUDA host adapter is null"); return Status::kErrorInternal; } } @@ -385,10 +394,25 @@ class GemmUniversalAdapter< CUTLASS_ASSERT(cuda_adapter == nullptr); void const* kernel = (void const*) device_kernel; if constexpr (GemmKernel::ArchTag::kMinComputeCapability == 90) { - if (is_static_1x1x1 && not launch_with_pdl) { - device_kernel<<>>(params); + if constexpr (is_static_1x1x1) { +#if (CUTLASS_DEBUG_TRACE_LEVEL > 1) + CUTLASS_TRACE_HOST("GemmUniversal::run: Launching static 1x1x1 kernel"); +#endif + launch_result = cutlass::kernel_launch( + grid, block, smem_size, stream, params, launch_with_pdl); + if (launch_result != Status::kSuccess) { + CUTLASS_TRACE_HOST("GemmUniversal::run: cutlass::kernel_launch reports failure"); + } +#if (CUTLASS_DEBUG_TRACE_LEVEL > 1) + else { + CUTLASS_TRACE_HOST("GemmUniversal::run: cutlass::kernel_launch reports success"); + } +#endif } else { +#if (CUTLASS_DEBUG_TRACE_LEVEL > 1) + CUTLASS_TRACE_HOST("GemmUniversal::run: Launching dynamic cluster kernel"); +#endif launch_result = ClusterLauncher::launch( grid, cluster, block, 
diff --git a/include/cutlass/gemm/device/gemm_universal_base.h b/include/cutlass/gemm/device/gemm_universal_base.h
index 63da07b418..e23191eae5 100644
--- a/include/cutlass/gemm/device/gemm_universal_base.h
+++ b/include/cutlass/gemm/device/gemm_universal_base.h
@@ -443,6 +443,8 @@ class GemmUniversalBase {
       "block: (" << block << "), "
       "SMEM: (" << kSharedStorageSize << ")");
 
+    cutlass::arch::synclog_setup();
+
     if constexpr (kEnableCudaHostAdapter) {
       CUTLASS_ASSERT(cuda_adapter);
       if (cuda_adapter) {
diff --git a/include/cutlass/gemm/device/gemv.h b/include/cutlass/gemm/device/gemv.h
index 341124942a..5e181743ef 100644
--- a/include/cutlass/gemm/device/gemv.h
+++ b/include/cutlass/gemm/device/gemv.h
@@ -141,6 +141,7 @@ class Gemv {
     int smem_size = int(sizeof(typename GemvKernel::SharedStorage));
 
     // Launch
+    cutlass::arch::synclog_setup();
     cutlass::Kernel<<>>(params_);
 
     //
diff --git a/include/cutlass/gemm/device/rank_2k.h b/include/cutlass/gemm/device/rank_2k.h
index d12621e6b9..296f38cad2 100644
--- a/include/cutlass/gemm/device/rank_2k.h
+++ b/include/cutlass/gemm/device/rank_2k.h
@@ -319,6 +319,7 @@ class Rank2K {
 
     int smem_size = int(sizeof(typename Rank2Kkernel::SharedStorage));
 
+    cutlass::arch::synclog_setup();
     cutlass::Kernel<<>>(params_);
 
     cudaError_t result = cudaGetLastError();
diff --git a/include/cutlass/gemm/device/rank_k.h b/include/cutlass/gemm/device/rank_k.h
index e6e9d025a4..ae18a11b80 100644
--- a/include/cutlass/gemm/device/rank_k.h
+++ b/include/cutlass/gemm/device/rank_k.h
@@ -296,6 +296,7 @@ class RankK {
 
     int smem_size = int(sizeof(typename RankKkernel::SharedStorage));
 
+    cutlass::arch::synclog_setup();
     cutlass::Kernel<<>>(params_);
 
     cudaError_t result = cudaGetLastError();
diff --git a/include/cutlass/gemm/device/symm.h b/include/cutlass/gemm/device/symm.h
index 223e1b0d10..c36ef959b1 100755
--- a/include/cutlass/gemm/device/symm.h
+++ b/include/cutlass/gemm/device/symm.h
@@ -337,6 +337,7 @@ class Symm {
       }
     }
 
+    cutlass::arch::synclog_setup();
     cutlass::Kernel<<>>(params_);
 
     cudaError_t result = cudaGetLastError();
diff --git a/include/cutlass/gemm/device/trmm.h b/include/cutlass/gemm/device/trmm.h
index e354e7a132..09b9152cbb 100644
--- a/include/cutlass/gemm/device/trmm.h
+++ b/include/cutlass/gemm/device/trmm.h
@@ -495,6 +495,7 @@ class Trmm {
       }
     }
 
+    cutlass::arch::synclog_setup();
     cutlass::Kernel<<>>(params_);
 
     cudaError_t result = cudaGetLastError();
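The dispatch-policy change below introduces is_asymmetric_dma_kernel_tag_of alongside the existing is_kernel_tag_of trait, both built on the same detection idiom: a primary template that inherits from false_type plus a partial specialization that matches any instantiation of a given schedule-tag template. A self-contained sketch of that idiom follows, with an invented tag name and std:: types standing in for the cute:: ones used in dispatch_policy.hpp.

// Sketch of the kernel-tag detection idiom (names are illustrative; the real
// traits live in namespace cutlass::detail and use cute::true_type / cute::false_type).
#include <type_traits>

template <class Schedule>
struct KernelTmaWarpSpecializedExample {};  // hypothetical schedule tag family

// Primary template: T is not an instantiation of the tag family U.
template <class T, template <class...> class U>
struct is_kernel_tag_of : std::false_type {};

// Partial specialization: T is U<Args...> for some Args, so the tag matches.
template <template <class...> class U, class... Args>
struct is_kernel_tag_of<U<Args...>, U> : std::true_type {};

template <class T, template <class...> class U>
constexpr bool is_kernel_tag_of_v = is_kernel_tag_of<T, U>::value;

static_assert(is_kernel_tag_of_v<KernelTmaWarpSpecializedExample<int>,
                                 KernelTmaWarpSpecializedExample>);
static_assert(!is_kernel_tag_of_v<double, KernelTmaWarpSpecializedExample>);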
diff --git a/include/cutlass/gemm/dispatch_policy.hpp b/include/cutlass/gemm/dispatch_policy.hpp
index c1c2308b9d..904e6af3cc 100644
--- a/include/cutlass/gemm/dispatch_policy.hpp
+++ b/include/cutlass/gemm/dispatch_policy.hpp
@@ -34,7 +34,7 @@
 #include "cutlass/gemm/gemm.h"
 
 #include "cute/layout.hpp"
-#include "cute/numeric/integral_constant.hpp"
+#include "cute/numeric/integral_constant.hpp" // cute::false_type
 
 //////////////////////////////////////////////////////////////////////////////
 
 namespace cutlass::detail {
@@ -48,6 +48,16 @@ struct is_kernel_tag_of, U> : cute::true_type {};
 template class U>
 constexpr bool is_kernel_tag_of_v = is_kernel_tag_of::value;
 
+template class U>
+struct is_asymmetric_dma_kernel_tag_of : cute::false_type {};
+
+template