@@ -124,12 +124,59 @@ using KernelTypesCompV3Wmma = ::testing::Types<
124124 std::tuple< Col, Col, Row, BF8, BF8, F32, F16, I64, I64, I32, I16, I16, I16, Intrawave, CompV3>
125125>;
126126
127- using KernelTypesCompV4 = ::testing::Types<
128- std::tuple< Row, Row, Row, F16, F16, F32, F16, I256, I256, I32, I32, I32, I16, Intrawave, CompV4>,
129- std::tuple< Row, Col, Row, F16, F16, F32, F16, I256, I256, I32, I32, I32, I16, Intrawave, CompV4>,
130- std::tuple< Col, Row, Row, F16, F16, F32, F16, I256, I256, I32, I32, I32, I16, Intrawave, CompV4>,
131- std::tuple< Col, Col, Row, F16, F16, F32, F16, I256, I256, I32, I32, I32, I16, Intrawave, CompV4>
132- >;
127+ // clang-format on
128+ template <typename ALayout, typename BLayout, typename CLayout, typename InputType>
129+ using CompV4Config = std::tuple<ALayout,
130+ BLayout,
131+ CLayout,
132+ InputType, // AType
133+ InputType, // BType
134+ F32, // AccType
135+ F16, // OutputType
136+ I256, // MBlockTileSize
137+ I256, // NBlockTileSize
138+ I32, // KBlockTileSize
139+ I32, // MWarpTileSize
140+ I32, // NWarpTileSize
141+ I16, // KWarpTileSize
142+ Intrawave,
143+ CompV4>;
144+
145+ using KernelTypesCompV4 = ::testing::Types<CompV4Config<Row, Row, Row, F16>,
146+ CompV4Config<Row, Col, Row, F16>,
147+ CompV4Config<Col, Row, Row, F16>,
148+ CompV4Config<Col, Col, Row, F16>,
149+ CompV4Config<Row, Row, Row, F8>,
150+ CompV4Config<Row, Col, Row, F8>,
151+ CompV4Config<Col, Row, Row, F8>,
152+ CompV4Config<Col, Col, Row, F8>>;
153+
154+ template <typename ALayout, typename BLayout, typename CLayout, typename InputType>
155+ using CompAsyncConfig = std::tuple<ALayout,
156+ BLayout,
157+ CLayout,
158+ InputType, // AType
159+ InputType, // BType
160+ F32, // AccType
161+ F16, // OutputType
162+ I256, // MBlockTileSize
163+ I256, // NBlockTileSize
164+ I32, // KBlockTileSize
165+ I32, // MWarpTileSize
166+ I32, // NWarpTileSize
167+ I16, // KWarpTileSize
168+ Intrawave,
169+ CompAsync>;
170+
171+ using KernelTypesCompAsync = ::testing::Types<CompAsyncConfig<Row, Row, Row, F16>,
172+ CompAsyncConfig<Row, Col, Row, F16>,
173+ CompAsyncConfig<Col, Row, Row, F16>,
174+ CompAsyncConfig<Col, Col, Row, F16>,
175+ CompAsyncConfig<Row, Row, Row, F8>,
176+ CompAsyncConfig<Row, Col, Row, F8>,
177+ CompAsyncConfig<Col, Row, Row, F8>,
178+ CompAsyncConfig<Col, Col, Row, F8>>;
179+ // clang-format off
133180
134181using KernelTypesCompV6 = ::testing::Types<
135182 std::tuple< Row, Row, Row, F16, F16, F32, F16, I256, I256, I64, I32, I32, I16, Intrawave, CompV6>,
@@ -153,12 +200,6 @@ using KernelTypesCompV6 = ::testing::Types<
153200 std::tuple< Col, Row, Row, INT8, INT8, INT32, INT32, I256, I256, I64, I32, I32, I16, Intrawave, CompV6>,
154201 std::tuple< Col, Col, Row, INT8, INT8, INT32, INT32, I256, I256, I64, I32, I32, I16, Intrawave, CompV6>
155202>;
156- using KernelTypesCompAsync = ::testing::Types<
157- std::tuple< Row, Row, Row, F16, F16, F32, F16, I256, I256, I32, I32, I32, I16, Intrawave, CompAsync>,
158- std::tuple< Row, Col, Row, F16, F16, F32, F16, I256, I256, I32, I32, I32, I16, Intrawave, CompAsync>,
159- std::tuple< Col, Row, Row, F16, F16, F32, F16, I256, I256, I32, I32, I32, I16, Intrawave, CompAsync>,
160- std::tuple< Col, Col, Row, F16, F16, F32, F16, I256, I256, I32, I32, I32, I16, Intrawave, CompAsync>
161- >;
162203
163204using KernelTypesCompV4Wmma = ::testing::Types<
164205 std::tuple< Row, Row, Row, F16, F16, F32, F16, I64, I64, I32, I16, I16, I16, Intrawave, CompV4>,
0 commit comments