Skip to content

Commit 36b051a

Browse files
committed
Add missing precision type combinations to CompV3Wmma from CompV3
1 parent f9078b4 commit 36b051a

File tree

6 files changed

+69
-6
lines changed

6 files changed

+69
-6
lines changed

test/ck_tile/gemm/CMakeLists.txt

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,16 @@ if(GPU_TARGETS MATCHES "gfx94|gfx95|gfx90a|gfx11|gfx12")
7171
endif()
7272

7373
if(GPU_TARGETS MATCHES "gfx11|gfx12")
74-
# On Radeon devices, build the WMMA version instead
74+
# On Radeon devices, build the WMMA version instead
75+
# Define architecture macros for compile-time detection
76+
# Append exactly one architecture define to both option lists.
# gfx12 is tested first, so when GPU_TARGETS matches both "gfx12" and "gfx11"
# only -DARCH_GFX12 is appended (the elseif branch is not evaluated).
# NOTE(review): where these two option lists are applied to the test targets is
# not visible in this hunk - confirm the WMMA test executables actually receive them.
if(GPU_TARGETS MATCHES "gfx12")
77+
list(APPEND EXAMPLE_GEMM_COMPILE_OPTIONS -DARCH_GFX12)
78+
list(APPEND EXAMPLE_GEMM_COMPILE_COMPUTE_V4_OPTIONS -DARCH_GFX12)
79+
elseif(GPU_TARGETS MATCHES "gfx11")
80+
list(APPEND EXAMPLE_GEMM_COMPILE_OPTIONS -DARCH_GFX11)
81+
list(APPEND EXAMPLE_GEMM_COMPILE_COMPUTE_V4_OPTIONS -DARCH_GFX11)
82+
endif()
83+
7584
add_gtest_executable(test_ck_tile_gemm_pipeline_mem_wmma test_gemm_pipeline_mem_wmma.cpp)
7685
add_gtest_executable(test_ck_tile_gemm_pipeline_compv3_wmma test_gemm_pipeline_compv3_wmma.cpp)
7786
add_gtest_executable(test_ck_tile_gemm_pipeline_compv4_wmma test_gemm_pipeline_compv4_wmma.cpp)

test/ck_tile/gemm/test_gemm_pipeline_compv3_wmma.cpp

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,21 @@ template <typename T>
99
// Test fixture for the CompV3 WMMA GEMM pipeline tests. It derives from
// TestCkTileGemmPipelineWmmaBase and passes itself as the second template argument.
class TestCkTileGemmPipelineCompV3Wmma
1010
: public TestCkTileGemmPipelineWmmaBase<T, TestCkTileGemmPipelineCompV3Wmma<T>>
1111
{
12+
public:
13+
// Compile-time predicate for the type tuple T.
// Returns false when the B layout is Row and the B data type is I4; for every
// other combination it forwards to the WMMA base class's check_data_type().
// NOTE(review): how callers use the result (e.g. to skip a test) is not visible here.
static constexpr bool check_data_type()
14+
{
15+
using Base1 = TestCkTileGemmPipelineWmmaBase<T, TestCkTileGemmPipelineCompV3Wmma<T>>;
16+
// Base2 is used only to read the BLayout / BDataType aliases derived from T.
// NOTE(review): Base2 is instantiated with Base1 as its second argument, which may
// differ from the instantiation Base1 actually inherits from - confirm this is intended.
using Base2 = TestCkTileGemmPipeline<T, Base1>;
17+
if constexpr(std::is_same_v<typename Base2::BLayout, Row> &&
18+
std::is_same_v<typename Base2::BDataType, I4>)
19+
{
20+
// Row-layout B combined with I4 is reported as not supported by this fixture.
return false;
21+
}
22+
else
23+
{
24+
return Base1::check_data_type();
25+
}
26+
}
1227
};
1328

1429
#define TEST_SUITE_NAME TestCkTileGemmPipelineCompV3Wmma

test/ck_tile/gemm/test_gemm_pipeline_kernel_types.hpp

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -122,25 +122,45 @@ using KernelTypesCompV3 = ::testing::Types<
122122

123123
// Type list used to instantiate the CompV3 WMMA typed tests: one std::tuple per configuration.
// The entries are grouped by the first two tuple slots (Row/Row, Row/Col, Col/Row, Col/Col);
// the third slot is Row in every entry, and every entry ends with Intrawave, CompV3.
// NOTE(review): slots 0-2 presumably map to ALayout/BLayout/CLayout and the next slots to the
// A/B/accumulator/C data types - confirm against TestCkTileGemmPipeline's tuple_element_t usage.
using KernelTypesCompV3Wmma = ::testing::Types<
124124
std::tuple< Row, Row, Row, F16, F16, F32, F16, I64, I64, I32, I16, I16, I16, Intrawave, CompV3>,
125+
std::tuple< Row, Row, Row, F16, I4, F32, F16, I64, I64, I32, I16, I16, I16, Intrawave, CompV3>,
125126
std::tuple< Row, Row, Row, BF16, BF16, F32, F16, I64, I64, I32, I16, I16, I16, Intrawave, CompV3>,
127+
std::tuple< Row, Row, Row, BF16, I4, F32, F16, I64, I64, I32, I16, I16, I16, Intrawave, CompV3>,
126128
std::tuple< Row, Row, Row, INT8, INT8, INT32, INT32, I64, I64, I32, I16, I16, I16, Intrawave, CompV3>,
127129
std::tuple< Row, Row, Row, F8, F8, F32, F16, I64, I64, I32, I16, I16, I16, Intrawave, CompV3>,
130+
std::tuple< Row, Row, Row, F8, BF8, F32, F16, I64, I64, I32, I16, I16, I16, Intrawave, CompV3>,
131+
std::tuple< Row, Row, Row, F8, I4, F32, F16, I64, I64, I32, I16, I16, I16, Intrawave, CompV3>,
128132
std::tuple< Row, Row, Row, BF8, BF8, F32, F16, I64, I64, I32, I16, I16, I16, Intrawave, CompV3>,
133+
std::tuple< Row, Row, Row, BF8, I4, F32, F16, I64, I64, I32, I16, I16, I16, Intrawave, CompV3>,
129134
std::tuple< Row, Col, Row, F16, F16, F32, F16, I64, I64, I32, I16, I16, I16, Intrawave, CompV3>,
135+
std::tuple< Row, Col, Row, F16, I4, F32, F16, I64, I64, I32, I16, I16, I16, Intrawave, CompV3>,
130136
std::tuple< Row, Col, Row, BF16, BF16, F32, F16, I64, I64, I32, I16, I16, I16, Intrawave, CompV3>,
137+
std::tuple< Row, Col, Row, BF16, I4, F32, F16, I64, I64, I32, I16, I16, I16, Intrawave, CompV3>,
131138
std::tuple< Row, Col, Row, INT8, INT8, INT32, INT32, I64, I64, I32, I16, I16, I16, Intrawave, CompV3>,
132139
std::tuple< Row, Col, Row, F8, F8, F32, F16, I64, I64, I32, I16, I16, I16, Intrawave, CompV3>,
140+
std::tuple< Row, Col, Row, F8, BF8, F32, F16, I64, I64, I32, I16, I16, I16, Intrawave, CompV3>,
141+
std::tuple< Row, Col, Row, F8, I4, F32, F16, I64, I64, I32, I16, I16, I16, Intrawave, CompV3>,
133142
std::tuple< Row, Col, Row, BF8, BF8, F32, F16, I64, I64, I32, I16, I16, I16, Intrawave, CompV3>,
143+
std::tuple< Row, Col, Row, BF8, I4, F32, F16, I64, I64, I32, I16, I16, I16, Intrawave, CompV3>,
134144
std::tuple< Col, Row, Row, F16, F16, F32, F16, I64, I64, I32, I16, I16, I16, Intrawave, CompV3>,
145+
std::tuple< Col, Row, Row, F16, I4, F32, F16, I64, I64, I32, I16, I16, I16, Intrawave, CompV3>,
135146
std::tuple< Col, Row, Row, BF16, BF16, F32, F16, I64, I64, I32, I16, I16, I16, Intrawave, CompV3>,
147+
std::tuple< Col, Row, Row, BF16, I4, F32, F16, I64, I64, I32, I16, I16, I16, Intrawave, CompV3>,
136148
std::tuple< Col, Row, Row, INT8, INT8, INT32, INT32, I64, I64, I32, I16, I16, I16, Intrawave, CompV3>,
137149
std::tuple< Col, Row, Row, F8, F8, F32, F16, I64, I64, I32, I16, I16, I16, Intrawave, CompV3>,
150+
std::tuple< Col, Row, Row, F8, BF8, F32, F16, I64, I64, I32, I16, I16, I16, Intrawave, CompV3>,
151+
std::tuple< Col, Row, Row, F8, I4, F32, F16, I64, I64, I32, I16, I16, I16, Intrawave, CompV3>,
138152
std::tuple< Col, Row, Row, BF8, BF8, F32, F16, I64, I64, I32, I16, I16, I16, Intrawave, CompV3>,
153+
std::tuple< Col, Row, Row, BF8, I4, F32, F16, I64, I64, I32, I16, I16, I16, Intrawave, CompV3>,
139154
std::tuple< Col, Col, Row, F16, F16, F32, F16, I64, I64, I32, I16, I16, I16, Intrawave, CompV3>,
155+
std::tuple< Col, Col, Row, F16, I4, F32, F16, I64, I64, I32, I16, I16, I16, Intrawave, CompV3>,
140156
std::tuple< Col, Col, Row, BF16, BF16, F32, F16, I64, I64, I32, I16, I16, I16, Intrawave, CompV3>,
157+
std::tuple< Col, Col, Row, BF16, I4, F32, F16, I64, I64, I32, I16, I16, I16, Intrawave, CompV3>,
141158
std::tuple< Col, Col, Row, INT8, INT8, INT32, INT32, I64, I64, I32, I16, I16, I16, Intrawave, CompV3>,
142159
std::tuple< Col, Col, Row, F8, F8, F32, F16, I64, I64, I32, I16, I16, I16, Intrawave, CompV3>,
143-
std::tuple< Col, Col, Row, BF8, BF8, F32, F16, I64, I64, I32, I16, I16, I16, Intrawave, CompV3>
160+
std::tuple< Col, Col, Row, F8, BF8, F32, F16, I64, I64, I32, I16, I16, I16, Intrawave, CompV3>,
161+
std::tuple< Col, Col, Row, F8, I4, F32, F16, I64, I64, I32, I16, I16, I16, Intrawave, CompV3>,
162+
std::tuple< Col, Col, Row, BF8, BF8, F32, F16, I64, I64, I32, I16, I16, I16, Intrawave, CompV3>,
163+
std::tuple< Col, Col, Row, BF8, I4, F32, F16, I64, I64, I32, I16, I16, I16, Intrawave, CompV3>
144164
>;
145165

146166
using KernelTypesCompV4 = ::testing::Types<

test/ck_tile/gemm/test_gemm_pipeline_ut_cases.inc

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,8 @@
33

44
#pragma once
55

6+
#include "ck_tile/core/arch/arch.hpp"
7+
68
TYPED_TEST(TEST_SUITE_NAME, SmallM)
79
{
810
std::vector<int> Ms{1, 2, 3, 4, 5, 6};
@@ -103,7 +105,11 @@ TYPED_TEST(TEST_SUITE_NAME, PaddK)
103105
{
104106
if constexpr(std::is_same_v<typename TestFixture::BDataType, ck_tile::pk_int4_t>)
105107
{
108+
#if defined(ARCH_GFX12) || defined(ARCH_GFX11)
109+
this->Run(M, N, K);
110+
#else
106111
EXPECT_THROW(this->Run(M, N, K), std::runtime_error);
112+
#endif
107113
}
108114
else
109115
{

test/ck_tile/gemm/test_gemm_pipeline_util.hpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -101,7 +101,7 @@ struct GemmPipelineTypeSelector<GemmPipelineType::CompAsync, Problem>
101101
template <typename Tuple, typename Derived>
102102
class TestCkTileGemmPipeline : public ::testing::Test
103103
{
104-
protected:
104+
public:
105105
using ALayout = std::tuple_element_t<0, Tuple>;
106106
using BLayout = std::tuple_element_t<1, Tuple>;
107107
using CLayout = std::tuple_element_t<2, Tuple>;
@@ -126,6 +126,7 @@ class TestCkTileGemmPipeline : public ::testing::Test
126126
static constexpr bool Persistent =
127127
ck_tile::tuple_element_or_default_t<Tuple, 15, std::false_type>::value;
128128

129+
protected:
129130
template <bool PadM, bool PadN, bool PadK, bool Preshuffle>
130131
void invoke_gemm(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& s)
131132
{

test/ck_tile/gemm/test_gemm_pipeline_wmma_base.hpp

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,11 +13,23 @@ class TestCkTileGemmPipelineWmmaBase : public TestCkTileGemmPipeline<Tuple, Deri
1313
public:
1414
static constexpr bool check_data_type()
1515
{
16-
using Base = TestCkTileGemmPipeline<Tuple, Derived>;
17-
using DeviceIp = ck_tile::remove_cvref_t<decltype(ck_tile::get_device_arch())>;
16+
using Base = TestCkTileGemmPipeline<Tuple, Derived>;
17+
18+
#if defined(ARCH_GFX12)
19+
using DeviceIp = ck_tile::gfx12_t;
20+
#elif defined(ARCH_GFX11)
21+
using DeviceIp = ck_tile::gfx11_t;
22+
#else
23+
#error "Unsupported architecture for WMMA"
24+
#endif
25+
26+
using BTypeToUse =
27+
std::conditional_t<std::is_same_v<typename Base::BDataType, ck_tile::pk_int4_t>,
28+
typename Base::ADataType,
29+
typename Base::BDataType>;
1830
return ck_tile::has_wmma_traits_v<DeviceIp,
1931
typename Base::ADataType,
20-
typename Base::BDataType,
32+
BTypeToUse,
2133
typename Base::AccDataType,
2234
ck_tile::constant<Base::M_Warp_Tile>::value,
2335
ck_tile::constant<Base::N_Warp_Tile>::value,

0 commit comments

Comments
 (0)