Skip to content

Commit 6e2a6ac

Browse files
committed
Rewrite mma unit tests
1 parent 8cdf476 commit 6e2a6ac

File tree

1 file changed

+90
-213
lines changed

1 file changed

+90
-213
lines changed

test/unit/cute/intel_xe/mma.cpp

Lines changed: 90 additions & 213 deletions
Original file line number | Diff line number | Diff line change
@@ -30,284 +30,161 @@
3030
*
3131
**************************************************************************************************/
3232

33-
#include "cutlass/detail/layout.hpp"
33+
#include "cutlass_unit_test.h"
3434

3535
#include <cute/tensor.hpp>
36-
#include <sycl/sycl.hpp>
37-
#include <cute/util/compat.hpp>
3836

39-
#include "cutlass_unit_test.h"
40-
#include "utils.hpp"
37+
#include "../cooperative_gemm_common.hpp"
4138

4239
using namespace cute;
43-
using namespace cutlass;
44-
using namespace compat::experimental;
45-
46-
#define SUBGROUP_SIZE (16)
47-
48-
template<class...> class GemmDeviceName;
49-
50-
template <class MMA, uint32_t wg_tile_m, uint32_t wg_tile_n, uint32_t sg_tile_m,
51-
uint32_t sg_tile_n, uint32_t sg_tile_k, class TA, class TB, class TC>
52-
void gemm_device(TA const *A, TB const *B, TC *C, uint32_t m, uint32_t n,
53-
uint32_t k) {
54-
using namespace cute;
55-
56-
// Represent the full tensors
57-
Tensor mA = make_tensor(make_gmem_ptr(A),
58-
make_layout(make_shape(m, k), make_stride(k, 1)));
59-
Tensor mB = make_tensor(make_gmem_ptr(B),
60-
make_layout(make_shape(n, k), make_stride(1, n)));
61-
Tensor mC = make_tensor(make_gmem_ptr(C),
62-
make_layout(make_shape(m, n), make_stride(n, 1)));
63-
64-
// Get the appropriate blocks for this thread block
65-
auto cta_coord = make_coord(BlockIdxX(), BlockIdxY(), _); // (m,n,k)
66-
67-
auto cta_tiler =
68-
make_shape(Int<wg_tile_m>{}, Int<wg_tile_n>{}, Int<sg_tile_k>{});
69-
Tensor gA = local_tile(mA, cta_tiler, cta_coord, Step<_1, X, _1>{});
70-
Tensor gB = local_tile(mB, cta_tiler, cta_coord, Step<X, _1, _1>{});
71-
Tensor gC = local_tile(mC, cta_tiler, cta_coord, Step<_1, _1, X>{});
72-
73-
TiledMMA mma = make_tiled_mma(
74-
MMA_Atom<MMA>{},
75-
Layout< // Require: subgroup_layout
76-
Shape<Int<cute::ceil_div(wg_tile_m, sg_tile_m)>,
77-
Int<cute::ceil_div(wg_tile_n, sg_tile_n)>, _1>>{});
78-
79-
ThrMMA thrd_mma = mma.get_slice(ThreadIdxX());
80-
81-
Tensor tgA = thrd_mma.partition_A(gA);
82-
Tensor fragment_A =
83-
thrd_mma.make_fragment_A(tgA(_, _, _, 0)); // (MMA, MMA_M, MMA_K)
84-
85-
Tensor tgB = thrd_mma.partition_B(gB);
86-
Tensor fragment_B =
87-
thrd_mma.make_fragment_B(tgB(_, _, _, 0)); // (MMA, MMA_N, MMA_K)
88-
89-
Tensor tgC = thrd_mma.partition_C(gC);
90-
Tensor fragment_C = thrd_mma.make_fragment_C(tgC); // (MMA, MMA_M, MMA_N)
91-
clear(fragment_C);
92-
93-
#define CUTLASS_ENABLE_DEBUG_PRINTS (0)
94-
95-
#undef LOG_THREAD
96-
#define LOG_THREAD (16)
97-
98-
#if CUTLASS_ENABLE_DEBUG_PRINTS
99-
if (thread(LOG_THREAD)) {
100-
print("===================== A :\n");
101-
102-
print(" mA : ");
103-
print(mA);
104-
print("\n");
105-
print(" gA : ");
106-
print(gA);
107-
print("\n");
108-
print("tgA : ");
109-
print(tgA);
110-
print("\n");
111-
print("fragment_A : ");
112-
print(fragment_A);
113-
print("\n\n");
114-
}
115-
#endif
116-
117-
#if CUTLASS_ENABLE_DEBUG_PRINTS
118-
if (thread(LOG_THREAD)) {
119-
print("===================== B :\n");
120-
121-
print(" mB : ");
122-
print(mB);
123-
print("\n");
124-
print(" gB : ");
125-
print(gB);
126-
print("\n");
127-
print("tgB : ");
128-
print(tgB);
129-
print("\n");
130-
print("fragment_B : ");
131-
print(fragment_B);
132-
print("\n\n");
133-
}
134-
#endif
135-
136-
#if CUTLASS_ENABLE_DEBUG_PRINTS
137-
if (thread(LOG_THREAD)) {
138-
print("===================== C :\n");
13940

140-
print(" mC : ");
141-
print(mC);
142-
print("\n");
143-
print(" gC : ");
144-
print(gC);
145-
print("\n");
146-
print("tgC : ");
147-
print(tgC);
148-
print("\n");
149-
print("fragment_C : ");
150-
print(fragment_C);
151-
print("\n\n");
152-
}
153-
#endif
154-
155-
auto k_tile_max = size<3>(tgA);
156-
for (int k_tile = 0; k_tile < k_tile_max; ++k_tile) {
157-
auto kA = tgA(_, _, _, k_tile);
158-
auto kB = tgB(_, _, _, k_tile);
159-
// Copy gmem to rmem for k_tile+1 with tA|tB thread-partitioned tensors
160-
copy(kA, fragment_A);
161-
copy(kB, fragment_B);
162-
163-
// Compute gemm on mma-partitioned smem
164-
cute::gemm(mma, fragment_A, fragment_B, fragment_C);
165-
}
166-
167-
copy(fragment_C, tgC);
41+
// File-local constants; run_mma_test forwards both as the leading template
// arguments of test_cooperative_gemm_col_major_layout.
// NOTE(review): presumably the work-group size and the maximum vectorized
// access width in bits -- confirm against the helper's declaration.
namespace {
42+
constexpr uint32_t thread_block_size = 128;
43+
constexpr uint32_t max_vec_bits = 128;
16844
}
16945

170-
// Setup params for a NT GEMM
171-
template <class MMA, uint32_t wg_tile_m, uint32_t wg_tile_n, uint32_t sg_tile_m,
172-
uint32_t sg_tile_n, uint32_t sg_tile_k, class TA, class TB, class TC>
173-
void gemm(int m, int n, int k, TA *A, TB *B, TC *C) {
174-
using namespace cute;
175-
176-
auto dimBlock = compat::dim3(SUBGROUP_SIZE * (wg_tile_m * wg_tile_n) /
177-
(sg_tile_m * sg_tile_n));
178-
auto dimGrid = compat::dim3(size(ceil_div(m, wg_tile_m)),
179-
size(ceil_div(n, wg_tile_n)));
180-
181-
launch<gemm_device<MMA, wg_tile_m, wg_tile_n, sg_tile_m, sg_tile_n, sg_tile_k,
182-
TA, TB, TC>, GemmDeviceName<MMA, TA, TB, TC>>(
183-
launch_policy{dimGrid, dimBlock,
184-
kernel_properties{sycl_exp::sub_group_size<SUBGROUP_SIZE>}},
185-
A, B, C, m, n, k);
186-
}
187-
188-
template <class MMA, uint32_t wg_tile_m, uint32_t wg_tile_n, uint32_t sg_tile_m,
189-
uint32_t sg_tile_n, uint32_t sg_tile_k, class TA, class TB, class TC>
190-
void MMA_Test(int m, int n, int k) {
191-
cutlass::host_vector<TA> h_A(m * k);
192-
cutlass::host_vector<TB> h_B(n * k);
193-
cutlass::host_vector<TC> h_C(m * n);
194-
195-
fill_matrix(h_A);
196-
fill_matrix(h_B);
197-
198-
cutlass::device_vector<TA> d_A = h_A;
199-
cutlass::device_vector<TB> d_B = h_B;
200-
cutlass::device_vector<TC> d_C = h_C;
201-
202-
::gemm<MMA, wg_tile_m, wg_tile_n, sg_tile_m, sg_tile_n, sg_tile_k>(
203-
m, n, k, d_A.data(), d_B.data(), d_C.data());
204-
compat::wait();
205-
206-
h_C = d_C;
207-
verify(m, n, k, h_A.data(), h_B.data(), h_C.data());
46+
// Shared driver for the tests below: default-constructs a
// TiledMMA<MMA_Atom<MMAAtom>, Layout<LayoutShape>> and hands it, together with
// shape_mnk, to test_cooperative_gemm_col_major_layout instantiated with
// <thread_block_size, max_vec_bits, TA, TB, TC>.
// layout_shape is never read in the body; only its type (LayoutShape) is used.
// NOTE(review): the helper is presumably declared in
// ../cooperative_gemm_common.hpp, and thread_block_size is the same for every
// LayoutShape -- confirm the helper does not expect a size derived from the
// tiled MMA.
template<typename MMAAtom, typename LayoutShape, typename ShapeMNK,
47+
typename TA, typename TB, typename TC>
48+
void run_mma_test(ShapeMNK shape_mnk, LayoutShape layout_shape) {
49+
auto tiled_mma = TiledMMA<MMA_Atom<MMAAtom>, Layout<LayoutShape>>{};
50+
test_cooperative_gemm_col_major_layout<thread_block_size, max_vec_bits, TA, TB, TC>(
51+
shape_mnk, tiled_mma);
20852
}
20953

21054
// XE_8x16x32_S32S8S8S32_TT via run_mma_test: LayoutShape 2x2x1, shape_mnk
// 64x64x32, TA/TB = int8_t, TC = int32_t.
// Lines marked '-' are the previous MMA_Test call that this commit removes.
TEST(PVC_CuTe_Xe, MMA_XE_8x16x32_S32S8S8S32_TT) {
211-
MMA_Test<XE_8x16x32_S32S8S8S32_TT, 64, 64, 8, 16, 32, int8_t, int8_t,
212-
int32_t>(512, 512, 256);
55+
run_mma_test<XE_8x16x32_S32S8S8S32_TT, Shape<_2, _2, _1>,
56+
decltype(Shape<_64, _64, _32>{}), int8_t, int8_t, int32_t>(
57+
Shape<_64, _64, _32>{}, Shape<_2, _2, _1>{});
21358
}
21459

21560
// XE_4x16x32_S32S8S8S32_TT via run_mma_test: LayoutShape 2x2x1, shape_mnk
// 32x64x32, TA/TB = int8_t, TC = int32_t.
// Lines marked '-' are the previous MMA_Test call that this commit removes.
TEST(PVC_CuTe_Xe, MMA_XE_4x16x32_S32S8S8S32_TT) {
216-
MMA_Test<XE_4x16x32_S32S8S8S32_TT, 32, 64, 4, 16, 32, int8_t, int8_t,
217-
int32_t>(512, 512, 256);
61+
run_mma_test<XE_4x16x32_S32S8S8S32_TT, Shape<_2, _2, _1>,
62+
decltype(Shape<_32, _64, _32>{}), int8_t, int8_t, int32_t>(
63+
Shape<_32, _64, _32>{}, Shape<_2, _2, _1>{});
21864
}
21965

22066
// XE_2x16x32_S32S8S8S32_TT via run_mma_test: LayoutShape 4x2x1, shape_mnk
// 16x32x32, TA/TB = int8_t, TC = int32_t.
// Lines marked '-' are the previous MMA_Test call that this commit removes.
TEST(PVC_CuTe_Xe, MMA_XE_2x16x32_S32S8S8S32_TT) {
221-
MMA_Test<XE_2x16x32_S32S8S8S32_TT, 16, 64, 2, 16, 32, int8_t, int8_t,
222-
int32_t>(512, 512, 256);
67+
run_mma_test<XE_2x16x32_S32S8S8S32_TT, Shape<_4, _2, _1>,
68+
decltype(Shape<_16, _32, _32>{}), int8_t, int8_t, int32_t>(
69+
Shape<_16, _32, _32>{}, Shape<_4, _2, _1>{});
22370
}
22471

22572
// XE_1x16x32_S32S8S8S32_TT via run_mma_test: LayoutShape 1x1x1, shape_mnk
// 8x64x32, TA/TB = int8_t, TC = int32_t.
// Lines marked '-' are the previous MMA_Test call that this commit removes.
TEST(PVC_CuTe_Xe, MMA_XE_1x16x32_S32S8S8S32_TT) {
226-
MMA_Test<XE_1x16x32_S32S8S8S32_TT, 8, 64, 1, 16, 32, int8_t, int8_t, int32_t>(
227-
512, 512, 256);
73+
run_mma_test<XE_1x16x32_S32S8S8S32_TT, Shape<_1, _1, _1>,
74+
decltype(Shape<_8, _64, _32>{}), int8_t, int8_t, int32_t>(
75+
Shape<_8, _64, _32>{}, Shape<_1, _1, _1>{});
22876
}
22977

23078
// XE_8x16x32_S32U8U8S32_TT via run_mma_test: LayoutShape 2x2x1, shape_mnk
// 64x64x32, TA/TB = uint8_t, TC = int32_t.
// Lines marked '-' are the previous MMA_Test call that this commit removes.
TEST(PVC_CuTe_Xe, MMA_XE_8x16x32_S32U8U8S32_TT) {
231-
MMA_Test<XE_8x16x32_S32U8U8S32_TT, 64, 64, 8, 16, 32, uint8_t, uint8_t,
232-
int32_t>(512, 512, 256);
79+
run_mma_test<XE_8x16x32_S32U8U8S32_TT, Shape<_2, _2, _1>,
80+
decltype(Shape<_64, _64, _32>{}), uint8_t, uint8_t, int32_t>(
81+
Shape<_64, _64, _32>{}, Shape<_2, _2, _1>{});
23382
}
23483

23584
// XE_4x16x32_S32U8U8S32_TT via run_mma_test: LayoutShape 2x2x1, shape_mnk
// 32x64x32, TA/TB = uint8_t, TC = int32_t.
// Lines marked '-' are the previous MMA_Test call that this commit removes.
TEST(PVC_CuTe_Xe, MMA_XE_4x16x32_S32U8U8S32_TT) {
236-
MMA_Test<XE_4x16x32_S32U8U8S32_TT, 32, 64, 4, 16, 32, uint8_t, uint8_t,
237-
int32_t>(512, 512, 256);
85+
run_mma_test<XE_4x16x32_S32U8U8S32_TT, Shape<_2, _2, _1>,
86+
decltype(Shape<_32, _64, _32>{}), uint8_t, uint8_t, int32_t>(
87+
Shape<_32, _64, _32>{}, Shape<_2, _2, _1>{});
23888
}
23989

24090
// XE_2x16x32_S32U8U8S32_TT via run_mma_test: LayoutShape 4x2x1, shape_mnk
// 16x32x32, TA/TB = uint8_t, TC = int32_t.
// Lines marked '-' are the previous MMA_Test call that this commit removes.
TEST(PVC_CuTe_Xe, MMA_XE_2x16x32_S32U8U8S32_TT) {
241-
MMA_Test<XE_2x16x32_S32U8U8S32_TT, 16, 64, 2, 16, 32, uint8_t, uint8_t,
242-
int32_t>(512, 512, 256);
91+
run_mma_test<XE_2x16x32_S32U8U8S32_TT, Shape<_4, _2, _1>,
92+
decltype(Shape<_16, _32, _32>{}), uint8_t, uint8_t, int32_t>(
93+
Shape<_16, _32, _32>{}, Shape<_4, _2, _1>{});
24394
}
24495

24596
// XE_1x16x32_S32U8U8S32_TT via run_mma_test: LayoutShape 1x1x1, shape_mnk
// 8x64x32, TA/TB = uint8_t, TC = int32_t.
// Lines marked '-' are the previous MMA_Test call that this commit removes.
TEST(PVC_CuTe_Xe, MMA_XE_1x16x32_S32U8U8S32_TT) {
246-
MMA_Test<XE_1x16x32_S32U8U8S32_TT, 8, 64, 1, 16, 32, uint8_t, uint8_t,
247-
int32_t>(512, 512, 256);
97+
run_mma_test<XE_1x16x32_S32U8U8S32_TT, Shape<_1, _1, _1>,
98+
decltype(Shape<_8, _64, _32>{}), uint8_t, uint8_t, int32_t>(
99+
Shape<_8, _64, _32>{}, Shape<_1, _1, _1>{});
248100
}
249101

250102
// XE_8x16x16_F32BF16BF16F32_TT via run_mma_test: LayoutShape 2x2x1, shape_mnk
// 64x64x16, TA/TB = cutlass::bfloat16_t, TC = float.
// Lines marked '-' are the previous MMA_Test call that this commit removes.
TEST(PVC_CuTe_Xe, MMA_XE_8x16x16_F32BF16BF16F32_TT) {
251-
MMA_Test<XE_8x16x16_F32BF16BF16F32_TT, 256, 256, 32, 64, 32, bfloat16_t,
252-
bfloat16_t, float>(512, 512, 256);
103+
run_mma_test<XE_8x16x16_F32BF16BF16F32_TT, Shape<_2, _2, _1>,
104+
decltype(Shape<_64, _64, _16>{}),
105+
cutlass::bfloat16_t, cutlass::bfloat16_t, float>(
106+
Shape<_64, _64, _16>{}, Shape<_2, _2, _1>{});
253107
}
254108

255109
// XE_4x16x16_F32BF16BF16F32_TT via run_mma_test: LayoutShape 2x2x1, shape_mnk
// 32x64x16, TA/TB = cutlass::bfloat16_t, TC = float.
// Lines marked '-' are the previous MMA_Test call that this commit removes.
TEST(PVC_CuTe_Xe, MMA_XE_4x16x16_F32BF16BF16F32_TT) {
256-
MMA_Test<XE_4x16x16_F32BF16BF16F32_TT, 32, 64, 4, 16, 16, bfloat16_t,
257-
bfloat16_t, float>(512, 512, 256);
110+
run_mma_test<XE_4x16x16_F32BF16BF16F32_TT, Shape<_2, _2, _1>,
111+
decltype(Shape<_32, _64, _16>{}),
112+
cutlass::bfloat16_t, cutlass::bfloat16_t, float>(
113+
Shape<_32, _64, _16>{}, Shape<_2, _2, _1>{});
258114
}
259115

260116
// XE_2x16x16_F32BF16BF16F32_TT via run_mma_test: LayoutShape 2x4x1, shape_mnk
// 128x128x16, TA/TB = cutlass::bfloat16_t, TC = float.
// NOTE(review): the other 2x16xK tests in this file use Shape<_4, _2, _1>;
// confirm the 2x4x1 layout here is intentional.
// Lines marked '-' are the previous MMA_Test call that this commit removes.
TEST(PVC_CuTe_Xe, MMA_XE_2x16x16_F32BF16BF16F32_TT) {
261-
MMA_Test<XE_2x16x16_F32BF16BF16F32_TT, 16, 64, 2, 16, 16, bfloat16_t,
262-
bfloat16_t, float>(512, 512, 256);
117+
run_mma_test<XE_2x16x16_F32BF16BF16F32_TT, Shape<_2, _4, _1>,
118+
decltype(Shape<_128, _128, _16>{}),
119+
cutlass::bfloat16_t, cutlass::bfloat16_t, float>(
120+
Shape<_128, _128, _16>{}, Shape<_2, _4, _1>{});
263121
}
264122

265123
// XE_1x16x16_F32BF16BF16F32_TT via run_mma_test: LayoutShape 1x1x1, shape_mnk
// 8x64x16, TA/TB = cutlass::bfloat16_t, TC = float.
// Lines marked '-' are the previous MMA_Test call that this commit removes.
TEST(PVC_CuTe_Xe, MMA_XE_1x16x16_F32BF16BF16F32_TT) {
266-
MMA_Test<XE_1x16x16_F32BF16BF16F32_TT, 8, 64, 1, 16, 16, bfloat16_t,
267-
bfloat16_t, float>(512, 512, 256);
124+
run_mma_test<XE_1x16x16_F32BF16BF16F32_TT, Shape<_1, _1, _1>,
125+
decltype(Shape<_8, _64, _16>{}),
126+
cutlass::bfloat16_t, cutlass::bfloat16_t, float>(
127+
Shape<_8, _64, _16>{}, Shape<_1, _1, _1>{});
268128
}
269129

270130
// XE_8x16x16_F32F16F16F32_TT via run_mma_test: LayoutShape 2x2x1, shape_mnk
// 64x64x16, TA/TB = cutlass::half_t, TC = float.
// Lines marked '-' are the previous MMA_Test call that this commit removes.
TEST(PVC_CuTe_Xe, MMA_XE_8x16x16_F32F16F16F32_TT) {
271-
MMA_Test<XE_8x16x16_F32F16F16F32_TT, 64, 64, 8, 16, 16, half_t, half_t,
272-
float>(512, 512, 256);
131+
run_mma_test<XE_8x16x16_F32F16F16F32_TT, Shape<_2, _2, _1>,
132+
decltype(Shape<_64, _64, _16>{}),
133+
cutlass::half_t, cutlass::half_t, float>(
134+
Shape<_64, _64, _16>{}, Shape<_2, _2, _1>{});
273135
}
274136

275137
// XE_4x16x16_F32F16F16F32_TT via run_mma_test: LayoutShape 2x2x1, shape_mnk
// 32x64x16, TA/TB = cutlass::half_t, TC = float.
// Lines marked '-' are the previous MMA_Test call that this commit removes.
TEST(PVC_CuTe_Xe, MMA_XE_4x16x16_F32F16F16F32_TT) {
276-
MMA_Test<XE_4x16x16_F32F16F16F32_TT, 32, 64, 4, 16, 16, half_t, half_t,
277-
float>(512, 512, 256);
138+
run_mma_test<XE_4x16x16_F32F16F16F32_TT, Shape<_2, _2, _1>,
139+
decltype(Shape<_32, _64, _16>{}),
140+
cutlass::half_t, cutlass::half_t, float>(
141+
Shape<_32, _64, _16>{}, Shape<_2, _2, _1>{});
278142
}
279143

280144
// XE_2x16x16_F32F16F16F32_TT via run_mma_test: LayoutShape 4x2x1, shape_mnk
// 128x128x16, TA/TB = cutlass::half_t, TC = float.
// Lines marked '-' are the previous MMA_Test call that this commit removes.
TEST(PVC_CuTe_Xe, MMA_XE_2x16x16_F32F16F16F32_TT) {
281-
MMA_Test<XE_2x16x16_F32F16F16F32_TT, 16, 64, 2, 16, 16, half_t, half_t,
282-
float>(512, 512, 256);
145+
run_mma_test<XE_2x16x16_F32F16F16F32_TT, Shape<_4, _2, _1>,
146+
decltype(Shape<_128, _128, _16>{}),
147+
cutlass::half_t, cutlass::half_t, float>(
148+
Shape<_128, _128, _16>{}, Shape<_4, _2, _1>{});
283149
}
284150

285151
// XE_1x16x16_F32F16F16F32_TT via run_mma_test: LayoutShape 1x1x1, shape_mnk
// 128x128x16, TA/TB = cutlass::half_t, TC = float.
// NOTE(review): the other 1x16xK MMA tests in this file use an 8x64xK shape;
// confirm 128x128x16 is intentional here.
// Lines marked '-' are the previous MMA_Test call that this commit removes.
TEST(PVC_CuTe_Xe, MMA_XE_1x16x16_F32F16F16F32_TT) {
286-
MMA_Test<XE_1x16x16_F32F16F16F32_TT, 8, 64, 1, 16, 16, half_t, half_t, float>(
287-
512, 512, 256);
152+
run_mma_test<XE_1x16x16_F32F16F16F32_TT, Shape<_1, _1, _1>,
153+
decltype(Shape<_128, _128, _16>{}),
154+
cutlass::half_t, cutlass::half_t, float>(
155+
Shape<_128, _128, _16>{}, Shape<_1, _1, _1>{});
288156
}
289157

290-
TEST(PVC_CuTe_Xe, FMA_XE_UniversalFMA_F32F32F32F32) {
291-
MMA_Test<UniversalFMA<float, float, float, float>, 64, 64, 8, 16, 16, float,
292-
float, float>(512, 512, 256);
158+
// XE_8x16x8_F32TF32TF32F32_TT via run_mma_test: LayoutShape 2x2x1, shape_mnk
// 64x64x8, TA/TB = cutlass::tfloat32_t, TC = float.
TEST(PVC_CuTe_Xe, MMA_XE_8x16x8_F32TF32TF32F32_TT) {
159+
run_mma_test<XE_8x16x8_F32TF32TF32F32_TT, Shape<_2, _2, _1>,
160+
decltype(Shape<_64, _64, _8>{}),
161+
cutlass::tfloat32_t, cutlass::tfloat32_t, float>(
162+
Shape<_64, _64, _8>{}, Shape<_2, _2, _1>{});
293163
}
294164

295-
TEST(PVC_CuTe_Xe, MMA_XE_1x16x8_F32TF32TF32F32_TT) {
296-
MMA_Test<XE_1x16x8_F32TF32TF32F32_TT, 64, 64, 8, 16, 16, tfloat32_t,
297-
tfloat32_t, float>(512, 512, 256);
165+
// XE_4x16x8_F32TF32TF32F32_TT via run_mma_test: LayoutShape 2x2x1, shape_mnk
// 32x64x8, TA/TB = cutlass::tfloat32_t, TC = float.
TEST(PVC_CuTe_Xe, MMA_XE_4x16x8_F32TF32TF32F32_TT) {
166+
run_mma_test<XE_4x16x8_F32TF32TF32F32_TT, Shape<_2, _2, _1>,
167+
decltype(Shape<_32, _64, _8>{}),
168+
cutlass::tfloat32_t, cutlass::tfloat32_t, float>(
169+
Shape<_32, _64, _8>{}, Shape<_2, _2, _1>{});
298170
}
299171

300172
// XE_2x16x8_F32TF32TF32F32_TT via run_mma_test: LayoutShape 4x2x1, shape_mnk
// 128x128x8, TA/TB = cutlass::tfloat32_t, TC = float.
// Lines marked '-' are the previous MMA_Test call that this commit removes.
TEST(PVC_CuTe_Xe, MMA_XE_2x16x8_F32TF32TF32F32_TT) {
301-
MMA_Test<XE_2x16x8_F32TF32TF32F32_TT, 64, 64, 8, 16, 16, tfloat32_t,
302-
tfloat32_t, float>(512, 512, 256);
173+
run_mma_test<XE_2x16x8_F32TF32TF32F32_TT, Shape<_4, _2, _1>,
174+
decltype(Shape<_128, _128, _8>{}),
175+
cutlass::tfloat32_t, cutlass::tfloat32_t, float>(
176+
Shape<_128, _128, _8>{}, Shape<_4, _2, _1>{});
303177
}
304178

305-
TEST(PVC_CuTe_Xe, MMA_XE_4x16x8_F32TF32TF32F32_TT) {
306-
MMA_Test<XE_4x16x8_F32TF32TF32F32_TT, 64, 64, 8, 16, 16, tfloat32_t,
307-
tfloat32_t, float>(512, 512, 256);
179+
// XE_1x16x8_F32TF32TF32F32_TT via run_mma_test: LayoutShape 1x1x1, shape_mnk
// 8x64x8, TA/TB = cutlass::tfloat32_t, TC = float.
TEST(PVC_CuTe_Xe, MMA_XE_1x16x8_F32TF32TF32F32_TT) {
180+
run_mma_test<XE_1x16x8_F32TF32TF32F32_TT, Shape<_1, _1, _1>,
181+
decltype(Shape<_8, _64, _8>{}),
182+
cutlass::tfloat32_t, cutlass::tfloat32_t, float>(
183+
Shape<_8, _64, _8>{}, Shape<_1, _1, _1>{});
308184
}
309185

310-
TEST(PVC_CuTe_Xe, MMA_XE_8x16x8_F32TF32TF32F32_TT) {
311-
MMA_Test<XE_8x16x8_F32TF32TF32F32_TT, 64, 64, 8, 16, 32, tfloat32_t,
312-
tfloat32_t, float>(512, 512, 256);
186+
// UniversalFMA<float, float, float, float> via run_mma_test: LayoutShape
// 1x1x1, shape_mnk 128x128x8, TA/TB/TC = float.
TEST(PVC_CuTe_Xe, FMA_XE_UniversalFMA_F32F32F32F32) {
187+
run_mma_test<UniversalFMA<float, float, float, float>, Shape<_1, _1, _1>,
188+
decltype(Shape<_128, _128, _8>{}), float, float, float>(
189+
Shape<_128, _128, _8>{}, Shape<_1, _1, _1>{});
313190
}

0 commit comments

Comments
 (0)