Skip to content

Commit 86d542f

Browse files
authored
[CK-Tile][Async gemm] add missing sync and f8 inputs test cases (#3000)
* add missing sync and f8 test cases * reformat test cases * comment failing cases * bump * reintroduce compv4 shapes
1 parent 0584399 commit 86d542f

File tree

2 files changed

+55
-12
lines changed

2 files changed

+55
-12
lines changed

include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_async.hpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -472,6 +472,8 @@ struct GemmPipelineAgBgCrCompAsync : public BaseGemmPipelineAgBgCrCompAsync<Prob
472472
block_gemm(c_block_tile, a_block_tile0, b_block_tile0);
473473
}
474474
{
475+
// write to LDS window(0) must complete before the local prefetch
476+
block_sync_lds_direct_load();
475477
// read A(num_loop), B(num_loop) from LDS window(0) to pipeline registers(0)
476478
Base::LocalPrefetch(a_block_tile0, a_lds_ld_window0, is_a_load_tr_v);
477479
Base::LocalPrefetch(b_block_tile0, b_lds_ld_window0, is_b_load_tr_v);

test/ck_tile/gemm/test_gemm_pipeline_kernel_types.hpp

Lines changed: 53 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -124,12 +124,59 @@ using KernelTypesCompV3Wmma = ::testing::Types<
124124
std::tuple< Col, Col, Row, BF8, BF8, F32, F16, I64, I64, I32, I16, I16, I16, Intrawave, CompV3>
125125
>;
126126

127-
using KernelTypesCompV4 = ::testing::Types<
128-
std::tuple< Row, Row, Row, F16, F16, F32, F16, I256, I256, I32, I32, I32, I16, Intrawave, CompV4>,
129-
std::tuple< Row, Col, Row, F16, F16, F32, F16, I256, I256, I32, I32, I32, I16, Intrawave, CompV4>,
130-
std::tuple< Col, Row, Row, F16, F16, F32, F16, I256, I256, I32, I32, I32, I16, Intrawave, CompV4>,
131-
std::tuple< Col, Col, Row, F16, F16, F32, F16, I256, I256, I32, I32, I32, I16, Intrawave, CompV4>
132-
>;
127+
// clang-format on
128+
template <typename ALayout, typename BLayout, typename CLayout, typename InputType>
129+
using CompV4Config = std::tuple<ALayout,
130+
BLayout,
131+
CLayout,
132+
InputType, // AType
133+
InputType, // BType
134+
F32, // AccType
135+
F16, // OutputType
136+
I256, // MBlockTileSize
137+
I256, // NBlockTileSize
138+
I32, // KBlockTileSize
139+
I32, // MWarpTileSize
140+
I32, // NWarpTileSize
141+
I16, // KWarpTileSize
142+
Intrawave,
143+
CompV4>;
144+
145+
using KernelTypesCompV4 = ::testing::Types<CompV4Config<Row, Row, Row, F16>,
146+
CompV4Config<Row, Col, Row, F16>,
147+
CompV4Config<Col, Row, Row, F16>,
148+
CompV4Config<Col, Col, Row, F16>,
149+
CompV4Config<Row, Row, Row, F8>,
150+
CompV4Config<Row, Col, Row, F8>,
151+
CompV4Config<Col, Row, Row, F8>,
152+
CompV4Config<Col, Col, Row, F8>>;
153+
154+
template <typename ALayout, typename BLayout, typename CLayout, typename InputType>
155+
using CompAsyncConfig = std::tuple<ALayout,
156+
BLayout,
157+
CLayout,
158+
InputType, // AType
159+
InputType, // BType
160+
F32, // AccType
161+
F16, // OutputType
162+
I256, // MBlockTileSize
163+
I256, // NBlockTileSize
164+
I32, // KBlockTileSize
165+
I32, // MWarpTileSize
166+
I32, // NWarpTileSize
167+
I16, // KWarpTileSize
168+
Intrawave,
169+
CompAsync>;
170+
171+
using KernelTypesCompAsync = ::testing::Types<CompAsyncConfig<Row, Row, Row, F16>,
172+
CompAsyncConfig<Row, Col, Row, F16>,
173+
CompAsyncConfig<Col, Row, Row, F16>,
174+
CompAsyncConfig<Col, Col, Row, F16>,
175+
CompAsyncConfig<Row, Row, Row, F8>,
176+
CompAsyncConfig<Row, Col, Row, F8>,
177+
CompAsyncConfig<Col, Row, Row, F8>,
178+
CompAsyncConfig<Col, Col, Row, F8>>;
179+
// clang-format off
133180

134181
using KernelTypesCompV6 = ::testing::Types<
135182
std::tuple< Row, Row, Row, F16, F16, F32, F16, I256, I256, I64, I32, I32, I16, Intrawave, CompV6>,
@@ -153,12 +200,6 @@ using KernelTypesCompV6 = ::testing::Types<
153200
std::tuple< Col, Row, Row, INT8, INT8, INT32, INT32, I256, I256, I64, I32, I32, I16, Intrawave, CompV6>,
154201
std::tuple< Col, Col, Row, INT8, INT8, INT32, INT32, I256, I256, I64, I32, I32, I16, Intrawave, CompV6>
155202
>;
156-
using KernelTypesCompAsync = ::testing::Types<
157-
std::tuple< Row, Row, Row, F16, F16, F32, F16, I256, I256, I32, I32, I32, I16, Intrawave, CompAsync>,
158-
std::tuple< Row, Col, Row, F16, F16, F32, F16, I256, I256, I32, I32, I32, I16, Intrawave, CompAsync>,
159-
std::tuple< Col, Row, Row, F16, F16, F32, F16, I256, I256, I32, I32, I32, I16, Intrawave, CompAsync>,
160-
std::tuple< Col, Col, Row, F16, F16, F32, F16, I256, I256, I32, I32, I32, I16, Intrawave, CompAsync>
161-
>;
162203

163204
using KernelTypesCompV4Wmma = ::testing::Types<
164205
std::tuple< Row, Row, Row, F16, F16, F32, F16, I64, I64, I32, I16, I16, I16, Intrawave, CompV4>,

0 commit comments

Comments
 (0)