Skip to content

Commit 6dbee64

Browse files
authored
[CK_BUILDER] Add backward weight instance traits for xdl cshuffle. (#3143)
* Add backward weight instance traits for xdl cshuffle. To keep instance test file sizes reasonable, we start a new test_bwd_weight_instance_traits.cpp test file.
* Fix copyright notices.
* Remove (c) symbol, replace with (C). Having UTF-8 in source caused an error with code generation.
1 parent 8681ced commit 6dbee64

File tree

6 files changed

+520
-3
lines changed

6 files changed

+520
-3
lines changed
Lines changed: 285 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,285 @@
1+
// Copyright (C) Advanced Micro Devices, Inc., or its affiliates.
2+
// SPDX-License-Identifier: MIT
3+
4+
#pragma once
5+
6+
#include "instance_traits.hpp"
7+
#include "ck/tensor_operation/gpu/device/convolution_backward_weight_specialization.hpp"
8+
9+
// Forward declaration to avoid circular dependency
10+
namespace ck::tensor_operation::device {
11+
12+
// Declaration-only view of the XDL CShuffle grouped-convolution backward-weight device operation.
// Only the name and the template parameter list are introduced here; that is all this header
// needs in order to specialize InstanceTraits on the type without including its definition.
template <ck::index_t NDimSpatial,
          // Tensor layouts
          typename InLayout,
          typename WeiLayout,
          typename OutLayout,
          // Data types
          typename InDataType,
          typename WeiDataType,
          typename OutDataType,
          typename AccDataType,
          // Elementwise operations
          typename InElementwiseOperation,
          typename WeiElementwiseOperation,
          typename OutElementwiseOperation,
          // Convolution specialization
          ConvolutionBackwardWeightSpecialization ConvBackwardWeightSpecialization,
          // Block / XDL tile configuration
          ck::index_t BlockSize,
          ck::index_t MPerBlock,
          ck::index_t NPerBlock,
          ck::index_t K0PerBlock,
          ck::index_t K1,
          ck::index_t MPerXDL,
          ck::index_t NPerXDL,
          ck::index_t MXdlPerWave,
          ck::index_t NXdlPerWave,
          // A block transfer
          typename ABlockTransferThreadClusterLengths_K0_M_K1,
          typename ABlockTransferThreadClusterArrangeOrder,
          typename ABlockTransferSrcAccessOrder,
          ck::index_t ABlockTransferSrcVectorDim,
          ck::index_t ABlockTransferSrcScalarPerVector,
          ck::index_t ABlockTransferDstScalarPerVector_K1,
          bool ABlockLdsAddExtraM,
          // B block transfer
          typename BBlockTransferThreadClusterLengths_K0_N_K1,
          typename BBlockTransferThreadClusterArrangeOrder,
          typename BBlockTransferSrcAccessOrder,
          ck::index_t BBlockTransferSrcVectorDim,
          ck::index_t BBlockTransferSrcScalarPerVector,
          ck::index_t BBlockTransferDstScalarPerVector_K1,
          bool BBlockLdsAddExtraN,
          // C shuffle / C block transfer
          ck::index_t CShuffleMXdlPerWavePerShuffle,
          ck::index_t CShuffleNXdlPerWavePerShuffle,
          typename CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
          ck::index_t CBlockTransferScalarPerVector_NWaveNPerXdl,
          // Compute types
          typename ComputeTypeA,
          typename ComputeTypeB,
          // Transpose transfer limits
          ck::index_t MaxTransposeTransferSrcScalarPerVector,
          ck::index_t MaxTransposeTransferDstScalarPerVector>
struct DeviceGroupedConvBwdWeight_Xdl_CShuffle;
57+
58+
} // namespace ck::tensor_operation::device
59+
60+
namespace ck_tile {
61+
namespace reflect {
62+
63+
template <ck::index_t NDimSpatial,
64+
typename InLayout_,
65+
typename WeiLayout_,
66+
typename OutLayout_,
67+
typename InDataType_,
68+
typename WeiDataType_,
69+
typename OutDataType_,
70+
typename AccDataType_,
71+
typename InElementwiseOperation_,
72+
typename WeiElementwiseOperation_,
73+
typename OutElementwiseOperation_,
74+
ck::tensor_operation::device::ConvolutionBackwardWeightSpecialization
75+
ConvBackwardWeightSpecialization,
76+
ck::index_t BlockSize,
77+
ck::index_t MPerBlock,
78+
ck::index_t NPerBlock,
79+
ck::index_t K0PerBlock,
80+
ck::index_t K1,
81+
ck::index_t MPerXDL,
82+
ck::index_t NPerXDL,
83+
ck::index_t MXdlPerWave,
84+
ck::index_t NXdlPerWave,
85+
typename ABlockTransferThreadClusterLengths_K0_M_K1_,
86+
typename ABlockTransferThreadClusterArrangeOrder_,
87+
typename ABlockTransferSrcAccessOrder_,
88+
ck::index_t ABlockTransferSrcVectorDim,
89+
ck::index_t ABlockTransferSrcScalarPerVector,
90+
ck::index_t ABlockTransferDstScalarPerVector_K1,
91+
bool ABlockLdsAddExtraM,
92+
typename BBlockTransferThreadClusterLengths_K0_N_K1_,
93+
typename BBlockTransferThreadClusterArrangeOrder_,
94+
typename BBlockTransferSrcAccessOrder_,
95+
ck::index_t BBlockTransferSrcVectorDim,
96+
ck::index_t BBlockTransferSrcScalarPerVector,
97+
ck::index_t BBlockTransferDstScalarPerVector_K1,
98+
bool BBlockLdsAddExtraN,
99+
ck::index_t CShuffleMXdlPerWavePerShuffle,
100+
ck::index_t CShuffleNXdlPerWavePerShuffle,
101+
typename CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock_,
102+
ck::index_t CBlockTransferScalarPerVector_NWaveNPerXdl,
103+
typename ComputeTypeA_,
104+
typename ComputeTypeB_,
105+
ck::index_t MaxTransposeTransferSrcScalarPerVector,
106+
ck::index_t MaxTransposeTransferDstScalarPerVector>
107+
struct InstanceTraits<ck::tensor_operation::device::DeviceGroupedConvBwdWeight_Xdl_CShuffle<
108+
NDimSpatial,
109+
InLayout_,
110+
WeiLayout_,
111+
OutLayout_,
112+
InDataType_,
113+
WeiDataType_,
114+
OutDataType_,
115+
AccDataType_,
116+
InElementwiseOperation_,
117+
WeiElementwiseOperation_,
118+
OutElementwiseOperation_,
119+
ConvBackwardWeightSpecialization,
120+
BlockSize,
121+
MPerBlock,
122+
NPerBlock,
123+
K0PerBlock,
124+
K1,
125+
MPerXDL,
126+
NPerXDL,
127+
MXdlPerWave,
128+
NXdlPerWave,
129+
ABlockTransferThreadClusterLengths_K0_M_K1_,
130+
ABlockTransferThreadClusterArrangeOrder_,
131+
ABlockTransferSrcAccessOrder_,
132+
ABlockTransferSrcVectorDim,
133+
ABlockTransferSrcScalarPerVector,
134+
ABlockTransferDstScalarPerVector_K1,
135+
ABlockLdsAddExtraM,
136+
BBlockTransferThreadClusterLengths_K0_N_K1_,
137+
BBlockTransferThreadClusterArrangeOrder_,
138+
BBlockTransferSrcAccessOrder_,
139+
BBlockTransferSrcVectorDim,
140+
BBlockTransferSrcScalarPerVector,
141+
BBlockTransferDstScalarPerVector_K1,
142+
BBlockLdsAddExtraN,
143+
CShuffleMXdlPerWavePerShuffle,
144+
CShuffleNXdlPerWavePerShuffle,
145+
CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock_,
146+
CBlockTransferScalarPerVector_NWaveNPerXdl,
147+
ComputeTypeA_,
148+
ComputeTypeB_,
149+
MaxTransposeTransferSrcScalarPerVector,
150+
MaxTransposeTransferDstScalarPerVector>>
151+
{
152+
static constexpr auto kTensorOpName = "DeviceGroupedConvBwdWeight_Xdl_CShuffle";
153+
154+
static constexpr ck::index_t kNDimSpatial = NDimSpatial;
155+
156+
using InLayout = InLayout_;
157+
using WeiLayout = WeiLayout_;
158+
using OutLayout = OutLayout_;
159+
160+
using InDataType = InDataType_;
161+
using WeiDataType = WeiDataType_;
162+
using OutDataType = OutDataType_;
163+
using AccDataType = AccDataType_;
164+
165+
using InElementwiseOperation = InElementwiseOperation_;
166+
using WeiElementwiseOperation = WeiElementwiseOperation_;
167+
using OutElementwiseOperation = OutElementwiseOperation_;
168+
169+
static constexpr auto kConvBackwardWeightSpecialization = ConvBackwardWeightSpecialization;
170+
171+
static constexpr ck::index_t kBlockSize = BlockSize;
172+
static constexpr ck::index_t kMPerBlock = MPerBlock;
173+
static constexpr ck::index_t kNPerBlock = NPerBlock;
174+
static constexpr ck::index_t kK0PerBlock = K0PerBlock;
175+
static constexpr ck::index_t kK1 = K1;
176+
static constexpr ck::index_t kMPerXDL = MPerXDL;
177+
static constexpr ck::index_t kNPerXDL = NPerXDL;
178+
static constexpr ck::index_t kMXdlPerWave = MXdlPerWave;
179+
static constexpr ck::index_t kNXdlPerWave = NXdlPerWave;
180+
181+
using ABlockTransferThreadClusterLengths_K0_M_K1 = ABlockTransferThreadClusterLengths_K0_M_K1_;
182+
using ABlockTransferThreadClusterArrangeOrder = ABlockTransferThreadClusterArrangeOrder_;
183+
using ABlockTransferSrcAccessOrder = ABlockTransferSrcAccessOrder_;
184+
static constexpr ck::index_t kABlockTransferSrcVectorDim = ABlockTransferSrcVectorDim;
185+
static constexpr ck::index_t kABlockTransferSrcScalarPerVector =
186+
ABlockTransferSrcScalarPerVector;
187+
static constexpr ck::index_t kABlockTransferDstScalarPerVector_K1 =
188+
ABlockTransferDstScalarPerVector_K1;
189+
static constexpr bool kABlockLdsAddExtraM = ABlockLdsAddExtraM;
190+
191+
using BBlockTransferThreadClusterLengths_K0_N_K1 = BBlockTransferThreadClusterLengths_K0_N_K1_;
192+
using BBlockTransferThreadClusterArrangeOrder = BBlockTransferThreadClusterArrangeOrder_;
193+
using BBlockTransferSrcAccessOrder = BBlockTransferSrcAccessOrder_;
194+
static constexpr ck::index_t kBBlockTransferSrcVectorDim = BBlockTransferSrcVectorDim;
195+
static constexpr ck::index_t kBBlockTransferSrcScalarPerVector =
196+
BBlockTransferSrcScalarPerVector;
197+
static constexpr ck::index_t kBBlockTransferDstScalarPerVector_K1 =
198+
BBlockTransferDstScalarPerVector_K1;
199+
static constexpr bool kBBlockLdsAddExtraN = BBlockLdsAddExtraN;
200+
201+
static constexpr ck::index_t kCShuffleMXdlPerWavePerShuffle = CShuffleMXdlPerWavePerShuffle;
202+
static constexpr ck::index_t kCShuffleNXdlPerWavePerShuffle = CShuffleNXdlPerWavePerShuffle;
203+
204+
using CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock =
205+
CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock_;
206+
static constexpr ck::index_t kCBlockTransferScalarPerVector_NWaveNPerXdl =
207+
CBlockTransferScalarPerVector_NWaveNPerXdl;
208+
209+
using ComputeTypeA = ComputeTypeA_;
210+
using ComputeTypeB = ComputeTypeB_;
211+
212+
static constexpr ck::index_t kMaxTransposeTransferSrcScalarPerVector =
213+
MaxTransposeTransferSrcScalarPerVector;
214+
static constexpr ck::index_t kMaxTransposeTransferDstScalarPerVector =
215+
MaxTransposeTransferDstScalarPerVector;
216+
217+
// Static member function to generate instance string
218+
static std::string instance_string()
219+
{
220+
std::ostringstream oss;
221+
222+
// Kernel type name
223+
oss << "DeviceGroupedConvBwdWeight_Xdl_CShuffle";
224+
225+
// Template parameters in exact order
226+
oss << "<" << kNDimSpatial; // 1. NDimSpatial
227+
oss << "," << detail::layout_name<InLayout>(); // 2. InLayout
228+
oss << "," << detail::layout_name<WeiLayout>(); // 3. WeiLayout
229+
oss << "," << detail::layout_name<OutLayout>(); // 4. OutLayout
230+
oss << "," << detail::type_name<InDataType>(); // 5. InDataType
231+
oss << "," << detail::type_name<WeiDataType>(); // 6. WeiDataType
232+
oss << "," << detail::type_name<OutDataType>(); // 7. OutDataType
233+
oss << "," << detail::type_name<AccDataType>(); // 8. AccDataType
234+
oss << ","
235+
<< detail::elementwise_op_name<InElementwiseOperation>(); // 9. InElementwiseOperation
236+
oss << ","
237+
<< detail::elementwise_op_name<WeiElementwiseOperation>(); // 10.
238+
// WeiElementwiseOperation
239+
oss << ","
240+
<< detail::elementwise_op_name<OutElementwiseOperation>(); // 11.
241+
// OutElementwiseOperation
242+
oss << ","
243+
<< detail::conv_bwd_weight_spec_name(
244+
kConvBackwardWeightSpecialization); // 12. ConvBackwardWeightSpecialization
245+
oss << "," << kBlockSize; // 13. BlockSize
246+
oss << "," << kMPerBlock; // 14. MPerBlock
247+
oss << "," << kNPerBlock; // 15. NPerBlock
248+
oss << "," << kK0PerBlock; // 16. K0PerBlock
249+
oss << "," << kK1; // 17. K1
250+
oss << "," << kMPerXDL; // 18. MPerXDL
251+
oss << "," << kNPerXDL; // 19. NPerXDL
252+
oss << "," << kMXdlPerWave; // 20. MXdlPerWave
253+
oss << "," << kNXdlPerWave; // 21. NXdlPerWave
254+
oss << "," << detail::sequence_name<ABlockTransferThreadClusterLengths_K0_M_K1>(); // 22.
255+
oss << "," << detail::sequence_name<ABlockTransferThreadClusterArrangeOrder>(); // 23.
256+
oss << "," << detail::sequence_name<ABlockTransferSrcAccessOrder>(); // 24.
257+
oss << "," << kABlockTransferSrcVectorDim; // 25.
258+
oss << "," << kABlockTransferSrcScalarPerVector; // 26.
259+
oss << "," << kABlockTransferDstScalarPerVector_K1; // 27.
260+
oss << "," << (kABlockLdsAddExtraM ? "true" : "false"); // 28.
261+
oss << "," << detail::sequence_name<BBlockTransferThreadClusterLengths_K0_N_K1>(); // 29.
262+
oss << "," << detail::sequence_name<BBlockTransferThreadClusterArrangeOrder>(); // 30.
263+
oss << "," << detail::sequence_name<BBlockTransferSrcAccessOrder>(); // 31.
264+
oss << "," << kBBlockTransferSrcVectorDim; // 32.
265+
oss << "," << kBBlockTransferSrcScalarPerVector; // 33.
266+
oss << "," << kBBlockTransferDstScalarPerVector_K1; // 34.
267+
oss << "," << (kBBlockLdsAddExtraN ? "true" : "false"); // 35.
268+
oss << "," << kCShuffleMXdlPerWavePerShuffle; // 36.
269+
oss << "," << kCShuffleNXdlPerWavePerShuffle; // 37.
270+
oss << ","
271+
<< detail::sequence_name<
272+
CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock>(); // 38.
273+
oss << "," << kCBlockTransferScalarPerVector_NWaveNPerXdl; // 39.
274+
oss << "," << detail::type_name<ComputeTypeA>(); // 40.
275+
oss << "," << detail::type_name<ComputeTypeB>(); // 41.
276+
oss << "," << kMaxTransposeTransferSrcScalarPerVector; // 42.
277+
oss << "," << kMaxTransposeTransferDstScalarPerVector; // 43.
278+
oss << ">";
279+
280+
return oss.str();
281+
}
282+
};
283+
284+
} // namespace reflect
285+
} // namespace ck_tile

experimental/builder/include/ck_tile/builder/reflect/instance_traits_util.hpp

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
1+
// Copyright (C) Advanced Micro Devices, Inc., or its affiliates.
12
// SPDX-License-Identifier: MIT
2-
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
33

44
// Utility functions and helpers for instance_traits.hpp
55
// Contains helper functions to convert types, enums, and sequences to string representations.
@@ -21,6 +21,7 @@
2121
#include <ck_tile/ops/common/tensor_layout.hpp>
2222
#include <ck/tensor_operation/gpu/element/element_wise_operation.hpp>
2323
#include <ck/tensor_operation/gpu/device/convolution_forward_specialization.hpp>
24+
#include <ck/tensor_operation/gpu/device/convolution_backward_weight_specialization.hpp>
2425
#include <ck/tensor_operation/gpu/device/gemm_specialization.hpp>
2526

2627
namespace ck_tile::reflect::detail {
@@ -112,6 +113,20 @@ conv_fwd_spec_name(ck::tensor_operation::device::ConvolutionForwardSpecializatio
112113
}
113114
}
114115

116+
// Convert ConvolutionBackwardWeightSpecialization enum to string
117+
constexpr std::string_view conv_bwd_weight_spec_name(
118+
ck::tensor_operation::device::ConvolutionBackwardWeightSpecialization spec)
119+
{
120+
using enum ck::tensor_operation::device::ConvolutionBackwardWeightSpecialization;
121+
switch(spec)
122+
{
123+
case Default: return "Default";
124+
case Filter1x1Stride1Pad0: return "Filter1x1Stride1Pad0";
125+
case Filter1x1Pad0: return "Filter1x1Pad0";
126+
case OddC: return "OddC";
127+
}
128+
}
129+
115130
// Convert GemmSpecialization enum to string
116131
constexpr std::string_view gemm_spec_name(ck::tensor_operation::device::GemmSpecialization spec)
117132
{

experimental/builder/test/CMakeLists.txt

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ endfunction()
2020
# Passes the test_ckb_conv_builder name and its source files to the add_ck_builder_test helper.
# NOTE(review): test_bwd_weight_instance_traits.cpp presumably holds the backward-weight
# instance-traits tests, kept separate from test_fwd_instance_traits.cpp - confirm.
add_ck_builder_test(test_ckb_conv_builder
2121
test_conv_builder.cpp
2222
test_fwd_instance_traits.cpp
23+
test_bwd_weight_instance_traits.cpp
2324
test_instance_traits_util.cpp)
2425

2526
add_ck_builder_test(test_ckb_inline_diff test_inline_diff.cpp)
@@ -30,7 +31,8 @@ add_ck_builder_test(test_ckb_get_instance_string
3031
test_get_instance_string_fwd_grp_conv.cpp
3132
test_get_instance_string_fwd_grp_conv_large_tensor.cpp
3233
test_get_instance_string_fwd_grp_conv_wmma.cpp
33-
test_get_instance_string_fwd_grp_conv_dl.cpp)
34+
test_get_instance_string_fwd_grp_conv_dl.cpp
35+
test_get_instance_string_bwd_weight_grp_conv_xdl.cpp)
3436

3537
# Testing the fwd convolution builder requires kernel compilation.
3638
# To enable parallel compilation, the individual tests are split into separate files.

0 commit comments

Comments
 (0)