|
| 1 | +// Copyright (C) Advanced Micro Devices, Inc., or its affiliates. |
| 2 | +// SPDX-License-Identifier: MIT |
| 3 | + |
| 4 | +#pragma once |
| 5 | + |
| 6 | +#include "instance_traits.hpp" |
| 7 | +#include "ck/tensor_operation/gpu/device/convolution_backward_weight_specialization.hpp" |
| 8 | + |
| 9 | +// Forward declaration to avoid circular dependency |
| 10 | +namespace ck::tensor_operation::device { |
| 11 | + |
| 12 | +template <ck::index_t NDimSpatial, |
| 13 | + typename InLayout, |
| 14 | + typename WeiLayout, |
| 15 | + typename OutLayout, |
| 16 | + typename InDataType, |
| 17 | + typename WeiDataType, |
| 18 | + typename OutDataType, |
| 19 | + typename AccDataType, |
| 20 | + typename InElementwiseOperation, |
| 21 | + typename WeiElementwiseOperation, |
| 22 | + typename OutElementwiseOperation, |
| 23 | + ck::tensor_operation::device::ConvolutionBackwardWeightSpecialization |
| 24 | + ConvBackwardWeightSpecialization, |
| 25 | + ck::index_t BlockSize, |
| 26 | + ck::index_t MPerBlock, |
| 27 | + ck::index_t NPerBlock, |
| 28 | + ck::index_t K0PerBlock, |
| 29 | + ck::index_t K1, |
| 30 | + ck::index_t MPerXDL, |
| 31 | + ck::index_t NPerXDL, |
| 32 | + ck::index_t MXdlPerWave, |
| 33 | + ck::index_t NXdlPerWave, |
| 34 | + typename ABlockTransferThreadClusterLengths_K0_M_K1, |
| 35 | + typename ABlockTransferThreadClusterArrangeOrder, |
| 36 | + typename ABlockTransferSrcAccessOrder, |
| 37 | + ck::index_t ABlockTransferSrcVectorDim, |
| 38 | + ck::index_t ABlockTransferSrcScalarPerVector, |
| 39 | + ck::index_t ABlockTransferDstScalarPerVector_K1, |
| 40 | + bool ABlockLdsAddExtraM, |
| 41 | + typename BBlockTransferThreadClusterLengths_K0_N_K1, |
| 42 | + typename BBlockTransferThreadClusterArrangeOrder, |
| 43 | + typename BBlockTransferSrcAccessOrder, |
| 44 | + ck::index_t BBlockTransferSrcVectorDim, |
| 45 | + ck::index_t BBlockTransferSrcScalarPerVector, |
| 46 | + ck::index_t BBlockTransferDstScalarPerVector_K1, |
| 47 | + bool BBlockLdsAddExtraN, |
| 48 | + ck::index_t CShuffleMXdlPerWavePerShuffle, |
| 49 | + ck::index_t CShuffleNXdlPerWavePerShuffle, |
| 50 | + typename CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, |
| 51 | + ck::index_t CBlockTransferScalarPerVector_NWaveNPerXdl, |
| 52 | + typename ComputeTypeA, |
| 53 | + typename ComputeTypeB, |
| 54 | + ck::index_t MaxTransposeTransferSrcScalarPerVector, |
| 55 | + ck::index_t MaxTransposeTransferDstScalarPerVector> |
| 56 | +struct DeviceGroupedConvBwdWeight_Xdl_CShuffle; |
| 57 | + |
| 58 | +} // namespace ck::tensor_operation::device |
| 59 | + |
| 60 | +namespace ck_tile { |
| 61 | +namespace reflect { |
| 62 | + |
| 63 | +template <ck::index_t NDimSpatial, |
| 64 | + typename InLayout_, |
| 65 | + typename WeiLayout_, |
| 66 | + typename OutLayout_, |
| 67 | + typename InDataType_, |
| 68 | + typename WeiDataType_, |
| 69 | + typename OutDataType_, |
| 70 | + typename AccDataType_, |
| 71 | + typename InElementwiseOperation_, |
| 72 | + typename WeiElementwiseOperation_, |
| 73 | + typename OutElementwiseOperation_, |
| 74 | + ck::tensor_operation::device::ConvolutionBackwardWeightSpecialization |
| 75 | + ConvBackwardWeightSpecialization, |
| 76 | + ck::index_t BlockSize, |
| 77 | + ck::index_t MPerBlock, |
| 78 | + ck::index_t NPerBlock, |
| 79 | + ck::index_t K0PerBlock, |
| 80 | + ck::index_t K1, |
| 81 | + ck::index_t MPerXDL, |
| 82 | + ck::index_t NPerXDL, |
| 83 | + ck::index_t MXdlPerWave, |
| 84 | + ck::index_t NXdlPerWave, |
| 85 | + typename ABlockTransferThreadClusterLengths_K0_M_K1_, |
| 86 | + typename ABlockTransferThreadClusterArrangeOrder_, |
| 87 | + typename ABlockTransferSrcAccessOrder_, |
| 88 | + ck::index_t ABlockTransferSrcVectorDim, |
| 89 | + ck::index_t ABlockTransferSrcScalarPerVector, |
| 90 | + ck::index_t ABlockTransferDstScalarPerVector_K1, |
| 91 | + bool ABlockLdsAddExtraM, |
| 92 | + typename BBlockTransferThreadClusterLengths_K0_N_K1_, |
| 93 | + typename BBlockTransferThreadClusterArrangeOrder_, |
| 94 | + typename BBlockTransferSrcAccessOrder_, |
| 95 | + ck::index_t BBlockTransferSrcVectorDim, |
| 96 | + ck::index_t BBlockTransferSrcScalarPerVector, |
| 97 | + ck::index_t BBlockTransferDstScalarPerVector_K1, |
| 98 | + bool BBlockLdsAddExtraN, |
| 99 | + ck::index_t CShuffleMXdlPerWavePerShuffle, |
| 100 | + ck::index_t CShuffleNXdlPerWavePerShuffle, |
| 101 | + typename CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock_, |
| 102 | + ck::index_t CBlockTransferScalarPerVector_NWaveNPerXdl, |
| 103 | + typename ComputeTypeA_, |
| 104 | + typename ComputeTypeB_, |
| 105 | + ck::index_t MaxTransposeTransferSrcScalarPerVector, |
| 106 | + ck::index_t MaxTransposeTransferDstScalarPerVector> |
| 107 | +struct InstanceTraits<ck::tensor_operation::device::DeviceGroupedConvBwdWeight_Xdl_CShuffle< |
| 108 | + NDimSpatial, |
| 109 | + InLayout_, |
| 110 | + WeiLayout_, |
| 111 | + OutLayout_, |
| 112 | + InDataType_, |
| 113 | + WeiDataType_, |
| 114 | + OutDataType_, |
| 115 | + AccDataType_, |
| 116 | + InElementwiseOperation_, |
| 117 | + WeiElementwiseOperation_, |
| 118 | + OutElementwiseOperation_, |
| 119 | + ConvBackwardWeightSpecialization, |
| 120 | + BlockSize, |
| 121 | + MPerBlock, |
| 122 | + NPerBlock, |
| 123 | + K0PerBlock, |
| 124 | + K1, |
| 125 | + MPerXDL, |
| 126 | + NPerXDL, |
| 127 | + MXdlPerWave, |
| 128 | + NXdlPerWave, |
| 129 | + ABlockTransferThreadClusterLengths_K0_M_K1_, |
| 130 | + ABlockTransferThreadClusterArrangeOrder_, |
| 131 | + ABlockTransferSrcAccessOrder_, |
| 132 | + ABlockTransferSrcVectorDim, |
| 133 | + ABlockTransferSrcScalarPerVector, |
| 134 | + ABlockTransferDstScalarPerVector_K1, |
| 135 | + ABlockLdsAddExtraM, |
| 136 | + BBlockTransferThreadClusterLengths_K0_N_K1_, |
| 137 | + BBlockTransferThreadClusterArrangeOrder_, |
| 138 | + BBlockTransferSrcAccessOrder_, |
| 139 | + BBlockTransferSrcVectorDim, |
| 140 | + BBlockTransferSrcScalarPerVector, |
| 141 | + BBlockTransferDstScalarPerVector_K1, |
| 142 | + BBlockLdsAddExtraN, |
| 143 | + CShuffleMXdlPerWavePerShuffle, |
| 144 | + CShuffleNXdlPerWavePerShuffle, |
| 145 | + CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock_, |
| 146 | + CBlockTransferScalarPerVector_NWaveNPerXdl, |
| 147 | + ComputeTypeA_, |
| 148 | + ComputeTypeB_, |
| 149 | + MaxTransposeTransferSrcScalarPerVector, |
| 150 | + MaxTransposeTransferDstScalarPerVector>> |
| 151 | +{ |
| 152 | + static constexpr auto kTensorOpName = "DeviceGroupedConvBwdWeight_Xdl_CShuffle"; |
| 153 | + |
| 154 | + static constexpr ck::index_t kNDimSpatial = NDimSpatial; |
| 155 | + |
| 156 | + using InLayout = InLayout_; |
| 157 | + using WeiLayout = WeiLayout_; |
| 158 | + using OutLayout = OutLayout_; |
| 159 | + |
| 160 | + using InDataType = InDataType_; |
| 161 | + using WeiDataType = WeiDataType_; |
| 162 | + using OutDataType = OutDataType_; |
| 163 | + using AccDataType = AccDataType_; |
| 164 | + |
| 165 | + using InElementwiseOperation = InElementwiseOperation_; |
| 166 | + using WeiElementwiseOperation = WeiElementwiseOperation_; |
| 167 | + using OutElementwiseOperation = OutElementwiseOperation_; |
| 168 | + |
| 169 | + static constexpr auto kConvBackwardWeightSpecialization = ConvBackwardWeightSpecialization; |
| 170 | + |
| 171 | + static constexpr ck::index_t kBlockSize = BlockSize; |
| 172 | + static constexpr ck::index_t kMPerBlock = MPerBlock; |
| 173 | + static constexpr ck::index_t kNPerBlock = NPerBlock; |
| 174 | + static constexpr ck::index_t kK0PerBlock = K0PerBlock; |
| 175 | + static constexpr ck::index_t kK1 = K1; |
| 176 | + static constexpr ck::index_t kMPerXDL = MPerXDL; |
| 177 | + static constexpr ck::index_t kNPerXDL = NPerXDL; |
| 178 | + static constexpr ck::index_t kMXdlPerWave = MXdlPerWave; |
| 179 | + static constexpr ck::index_t kNXdlPerWave = NXdlPerWave; |
| 180 | + |
| 181 | + using ABlockTransferThreadClusterLengths_K0_M_K1 = ABlockTransferThreadClusterLengths_K0_M_K1_; |
| 182 | + using ABlockTransferThreadClusterArrangeOrder = ABlockTransferThreadClusterArrangeOrder_; |
| 183 | + using ABlockTransferSrcAccessOrder = ABlockTransferSrcAccessOrder_; |
| 184 | + static constexpr ck::index_t kABlockTransferSrcVectorDim = ABlockTransferSrcVectorDim; |
| 185 | + static constexpr ck::index_t kABlockTransferSrcScalarPerVector = |
| 186 | + ABlockTransferSrcScalarPerVector; |
| 187 | + static constexpr ck::index_t kABlockTransferDstScalarPerVector_K1 = |
| 188 | + ABlockTransferDstScalarPerVector_K1; |
| 189 | + static constexpr bool kABlockLdsAddExtraM = ABlockLdsAddExtraM; |
| 190 | + |
| 191 | + using BBlockTransferThreadClusterLengths_K0_N_K1 = BBlockTransferThreadClusterLengths_K0_N_K1_; |
| 192 | + using BBlockTransferThreadClusterArrangeOrder = BBlockTransferThreadClusterArrangeOrder_; |
| 193 | + using BBlockTransferSrcAccessOrder = BBlockTransferSrcAccessOrder_; |
| 194 | + static constexpr ck::index_t kBBlockTransferSrcVectorDim = BBlockTransferSrcVectorDim; |
| 195 | + static constexpr ck::index_t kBBlockTransferSrcScalarPerVector = |
| 196 | + BBlockTransferSrcScalarPerVector; |
| 197 | + static constexpr ck::index_t kBBlockTransferDstScalarPerVector_K1 = |
| 198 | + BBlockTransferDstScalarPerVector_K1; |
| 199 | + static constexpr bool kBBlockLdsAddExtraN = BBlockLdsAddExtraN; |
| 200 | + |
| 201 | + static constexpr ck::index_t kCShuffleMXdlPerWavePerShuffle = CShuffleMXdlPerWavePerShuffle; |
| 202 | + static constexpr ck::index_t kCShuffleNXdlPerWavePerShuffle = CShuffleNXdlPerWavePerShuffle; |
| 203 | + |
| 204 | + using CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock = |
| 205 | + CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock_; |
| 206 | + static constexpr ck::index_t kCBlockTransferScalarPerVector_NWaveNPerXdl = |
| 207 | + CBlockTransferScalarPerVector_NWaveNPerXdl; |
| 208 | + |
| 209 | + using ComputeTypeA = ComputeTypeA_; |
| 210 | + using ComputeTypeB = ComputeTypeB_; |
| 211 | + |
| 212 | + static constexpr ck::index_t kMaxTransposeTransferSrcScalarPerVector = |
| 213 | + MaxTransposeTransferSrcScalarPerVector; |
| 214 | + static constexpr ck::index_t kMaxTransposeTransferDstScalarPerVector = |
| 215 | + MaxTransposeTransferDstScalarPerVector; |
| 216 | + |
| 217 | + // Static member function to generate instance string |
| 218 | + static std::string instance_string() |
| 219 | + { |
| 220 | + std::ostringstream oss; |
| 221 | + |
| 222 | + // Kernel type name |
| 223 | + oss << "DeviceGroupedConvBwdWeight_Xdl_CShuffle"; |
| 224 | + |
| 225 | + // Template parameters in exact order |
| 226 | + oss << "<" << kNDimSpatial; // 1. NDimSpatial |
| 227 | + oss << "," << detail::layout_name<InLayout>(); // 2. InLayout |
| 228 | + oss << "," << detail::layout_name<WeiLayout>(); // 3. WeiLayout |
| 229 | + oss << "," << detail::layout_name<OutLayout>(); // 4. OutLayout |
| 230 | + oss << "," << detail::type_name<InDataType>(); // 5. InDataType |
| 231 | + oss << "," << detail::type_name<WeiDataType>(); // 6. WeiDataType |
| 232 | + oss << "," << detail::type_name<OutDataType>(); // 7. OutDataType |
| 233 | + oss << "," << detail::type_name<AccDataType>(); // 8. AccDataType |
| 234 | + oss << "," |
| 235 | + << detail::elementwise_op_name<InElementwiseOperation>(); // 9. InElementwiseOperation |
| 236 | + oss << "," |
| 237 | + << detail::elementwise_op_name<WeiElementwiseOperation>(); // 10. |
| 238 | + // WeiElementwiseOperation |
| 239 | + oss << "," |
| 240 | + << detail::elementwise_op_name<OutElementwiseOperation>(); // 11. |
| 241 | + // OutElementwiseOperation |
| 242 | + oss << "," |
| 243 | + << detail::conv_bwd_weight_spec_name( |
| 244 | + kConvBackwardWeightSpecialization); // 12. ConvBackwardWeightSpecialization |
| 245 | + oss << "," << kBlockSize; // 13. BlockSize |
| 246 | + oss << "," << kMPerBlock; // 14. MPerBlock |
| 247 | + oss << "," << kNPerBlock; // 15. NPerBlock |
| 248 | + oss << "," << kK0PerBlock; // 16. K0PerBlock |
| 249 | + oss << "," << kK1; // 17. K1 |
| 250 | + oss << "," << kMPerXDL; // 18. MPerXDL |
| 251 | + oss << "," << kNPerXDL; // 19. NPerXDL |
| 252 | + oss << "," << kMXdlPerWave; // 20. MXdlPerWave |
| 253 | + oss << "," << kNXdlPerWave; // 21. NXdlPerWave |
| 254 | + oss << "," << detail::sequence_name<ABlockTransferThreadClusterLengths_K0_M_K1>(); // 22. |
| 255 | + oss << "," << detail::sequence_name<ABlockTransferThreadClusterArrangeOrder>(); // 23. |
| 256 | + oss << "," << detail::sequence_name<ABlockTransferSrcAccessOrder>(); // 24. |
| 257 | + oss << "," << kABlockTransferSrcVectorDim; // 25. |
| 258 | + oss << "," << kABlockTransferSrcScalarPerVector; // 26. |
| 259 | + oss << "," << kABlockTransferDstScalarPerVector_K1; // 27. |
| 260 | + oss << "," << (kABlockLdsAddExtraM ? "true" : "false"); // 28. |
| 261 | + oss << "," << detail::sequence_name<BBlockTransferThreadClusterLengths_K0_N_K1>(); // 29. |
| 262 | + oss << "," << detail::sequence_name<BBlockTransferThreadClusterArrangeOrder>(); // 30. |
| 263 | + oss << "," << detail::sequence_name<BBlockTransferSrcAccessOrder>(); // 31. |
| 264 | + oss << "," << kBBlockTransferSrcVectorDim; // 32. |
| 265 | + oss << "," << kBBlockTransferSrcScalarPerVector; // 33. |
| 266 | + oss << "," << kBBlockTransferDstScalarPerVector_K1; // 34. |
| 267 | + oss << "," << (kBBlockLdsAddExtraN ? "true" : "false"); // 35. |
| 268 | + oss << "," << kCShuffleMXdlPerWavePerShuffle; // 36. |
| 269 | + oss << "," << kCShuffleNXdlPerWavePerShuffle; // 37. |
| 270 | + oss << "," |
| 271 | + << detail::sequence_name< |
| 272 | + CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock>(); // 38. |
| 273 | + oss << "," << kCBlockTransferScalarPerVector_NWaveNPerXdl; // 39. |
| 274 | + oss << "," << detail::type_name<ComputeTypeA>(); // 40. |
| 275 | + oss << "," << detail::type_name<ComputeTypeB>(); // 41. |
| 276 | + oss << "," << kMaxTransposeTransferSrcScalarPerVector; // 42. |
| 277 | + oss << "," << kMaxTransposeTransferDstScalarPerVector; // 43. |
| 278 | + oss << ">"; |
| 279 | + |
| 280 | + return oss.str(); |
| 281 | + } |
| 282 | +}; |
| 283 | + |
| 284 | +} // namespace reflect |
| 285 | +} // namespace ck_tile |
0 commit comments