From e5eb4edd1acbe0b03609a99cba72a729df535f26 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ville=20Pietil=C3=A4?= Date: Wed, 29 Oct 2025 13:29:48 +0000 Subject: [PATCH 01/12] Add device operation to conv signature. Use unions to hold conv layouts and device operations. --- .../include/ck_tile/builder/conv_factory.hpp | 26 ++++++++- .../builder/conv_signature_concepts.hpp | 12 +++- .../builder/include/ck_tile/builder/types.hpp | 56 +++++++++++++++++++ .../test/conv/test_ckb_conv_fwd_1d_bf16.cpp | 6 +- .../test/conv/test_ckb_conv_fwd_2d_bf16.cpp | 12 ++-- .../test/conv/test_ckb_conv_fwd_2d_fp16.cpp | 6 +- .../test/conv/test_ckb_conv_fwd_2d_fp32.cpp | 6 +- .../test/conv/test_ckb_conv_fwd_3d_bf16.cpp | 6 +- .../test/conv/test_ckb_conv_fwd_3d_fp16.cpp | 6 +- .../test/conv/test_ckb_conv_fwd_3d_fp32.cpp | 6 +- .../test/impl/conv_signature_types.hpp | 9 +-- .../test/utils/ckb_conv_test_common.hpp | 2 +- 12 files changed, 127 insertions(+), 26 deletions(-) diff --git a/experimental/builder/include/ck_tile/builder/conv_factory.hpp b/experimental/builder/include/ck_tile/builder/conv_factory.hpp index de8ba4f648..252d423716 100644 --- a/experimental/builder/include/ck_tile/builder/conv_factory.hpp +++ b/experimental/builder/include/ck_tile/builder/conv_factory.hpp @@ -158,6 +158,28 @@ struct ConvTensorLayouts +consteval auto GetTensorLayout() +{ + + if constexpr(SPATIAL_DIM == 1) + { + return factory_internal::ConvTensorLayouts{}; + } + else if constexpr(SPATIAL_DIM == 2) + { + return factory_internal::ConvTensorLayouts{}; + } + else if constexpr(SPATIAL_DIM == 3) + { + return factory_internal::ConvTensorLayouts{}; + } + else + { + static_assert(false, "Unsupported spatial dimension for convolution layout."); + } +} + // Type mappings from builder convolution data type to CK tensor types. 
template struct ConvTensorTypes @@ -440,8 +462,8 @@ template { static constexpr size_t SPATIAL_DIM = SIGNATURE.spatial_dim; - using Layouts = - factory_internal::ConvTensorLayouts; + /*static constexpr auto*/ + using Layouts = decltype(factory_internal::GetTensorLayout()); using Types = factory_internal::ConvTensorTypes; using Ops = factory_internal::ElementwiseOps; using AlgorithmType = decltype(ALGORITHM); diff --git a/experimental/builder/include/ck_tile/builder/conv_signature_concepts.hpp b/experimental/builder/include/ck_tile/builder/conv_signature_concepts.hpp index 0851f0061e..76e5590ad6 100644 --- a/experimental/builder/include/ck_tile/builder/conv_signature_concepts.hpp +++ b/experimental/builder/include/ck_tile/builder/conv_signature_concepts.hpp @@ -17,6 +17,7 @@ // signature at compile time. #pragma once +#include #include #include @@ -40,16 +41,21 @@ template concept ConvDataType = (T == DataType::FP32) || (T == DataType::FP16) || (T == DataType::BF16) || (T == DataType::FP8) || (T == DataType::I8) || (T == DataType::U8); +template +concept ConvDeviceOp = std::same_as, GroupConvDeviceOp>; + +template +concept ConvLayout = std::same_as, GroupConvLayout>; + // Concept for a type that defines a convolution's operational signature. template concept ConvSignatureDescriptor = requires(T t) { { t.spatial_dim } -> std::convertible_to; { t.direction } -> std::convertible_to; - requires std::convertible_to || - std::convertible_to || - std::convertible_to; + { t.layout } -> ConvLayout; { t.data_type } -> std::convertible_to; { t.elementwise_operation } -> std::convertible_to; + { t.device_operation } -> ConvDeviceOp; }; // Concept to validate a convolution signature's values. 
diff --git a/experimental/builder/include/ck_tile/builder/types.hpp b/experimental/builder/include/ck_tile/builder/types.hpp index 7f49e77f81..509f240edd 100644 --- a/experimental/builder/include/ck_tile/builder/types.hpp +++ b/experimental/builder/include/ck_tile/builder/types.hpp @@ -48,6 +48,18 @@ enum class GroupConvLayout3D NGCDHW_GKCZYX_NGKDHW, }; +struct GroupConvLayout { + union { + GroupConvLayout1D _1d; + GroupConvLayout2D _2d; + GroupConvLayout3D _3d; + }; + + constexpr GroupConvLayout(GroupConvLayout1D layout) : _1d(layout) {} + constexpr GroupConvLayout(GroupConvLayout2D layout) : _2d(layout) {} + constexpr GroupConvLayout(GroupConvLayout3D layout) : _3d(layout) {} +}; + // Direction of the convolution operation. enum class ConvDirection { @@ -56,6 +68,50 @@ enum class ConvDirection BACKWARD_WEIGHT }; +// Forward convolution device operations. +enum class FwdGroupConvDeviceOperation +{ + DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK, + DeviceGroupedConvFwdMultipleD_Wmma_CShuffle, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor +}; + +// Backward data convolution device operations. +enum class BwdDataGroupConvDeviceOperation +{ + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1, + DeviceGroupedConvBwdDataMultipleD, + DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle +}; + +// Backward weight convolution device operations. 
+enum class BwdWeightGroupConvDeviceOperation +{ + DeviceGroupedConvBwdWeight, + DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle, + DeviceGroupedConvBwdWeight_Xdl_CShuffle, + DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle, + DeviceGroupedConvBwdWeight_Wmma_CShuffle, + DeviceGroupedConvBwdWeight_Xdl_CShuffleV3, + DeviceGroupedConvBwdWeightMultipleD, + DeviceGroupedConvBwdWeight_Dl +}; + +// Structural type for device operation +struct GroupConvDeviceOp { + union { + FwdGroupConvDeviceOperation _fwd; + BwdDataGroupConvDeviceOperation _bwd_data; + BwdWeightGroupConvDeviceOperation _bwd_weight; + }; + + constexpr GroupConvDeviceOp(FwdGroupConvDeviceOperation op) : _fwd(op) {} + constexpr GroupConvDeviceOp(BwdDataGroupConvDeviceOperation op) : _bwd_data(op) {} + constexpr GroupConvDeviceOp(BwdWeightGroupConvDeviceOperation op) : _bwd_weight(op) {} +}; + // Fused element-wise operations. enum class ElementwiseOperation { diff --git a/experimental/builder/test/conv/test_ckb_conv_fwd_1d_bf16.cpp b/experimental/builder/test/conv/test_ckb_conv_fwd_1d_bf16.cpp index d5b8802896..b660fa3303 100644 --- a/experimental/builder/test/conv/test_ckb_conv_fwd_1d_bf16.cpp +++ b/experimental/builder/test/conv/test_ckb_conv_fwd_1d_bf16.cpp @@ -9,12 +9,14 @@ namespace ck_tile::builder::testing { TEST(FwdConvInstances, Create_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_Instance_1D_BF16_ChannelsFirst_scale) { - constexpr ConvSignature FwdConvSignature{ + constexpr ConvSignature FwdConvSignature{ .spatial_dim = 1, .direction = ConvDirection::FORWARD, .layout = GroupConvLayout1D::NGCW_GKXC_NGKW, .data_type = DataType::BF16, - .elementwise_operation = ElementwiseOperation::SCALE}; + .elementwise_operation = ElementwiseOperation::SCALE, + .device_operation = FwdGroupConvDeviceOperation::DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 + }; constexpr ThreadBlock FwdThreadBlock{.block_size = 256, .tile_size = {.m = 256, .n = 256, .k = 32}}; diff --git 
a/experimental/builder/test/conv/test_ckb_conv_fwd_2d_bf16.cpp b/experimental/builder/test/conv/test_ckb_conv_fwd_2d_bf16.cpp index 77c5c80489..cf942f56a1 100644 --- a/experimental/builder/test/conv/test_ckb_conv_fwd_2d_bf16.cpp +++ b/experimental/builder/test/conv/test_ckb_conv_fwd_2d_bf16.cpp @@ -8,12 +8,14 @@ namespace ck_tile::builder::testing { TEST(FwdConvInstances, Create_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_Instance_2D_BF16_ChannelsLast) { - constexpr ConvSignature FwdConvSignature{ + constexpr ConvSignature FwdConvSignature{ .spatial_dim = 2, .direction = ConvDirection::FORWARD, .layout = GroupConvLayout2D::NHWGC_GKYXC_NHWGK, .data_type = DataType::BF16, - .elementwise_operation = ElementwiseOperation::PASS_THROUGH}; + .elementwise_operation = ElementwiseOperation::PASS_THROUGH, + .device_operation = FwdGroupConvDeviceOperation::DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 + }; constexpr ThreadBlock FwdThreadBlock{.block_size = 256, .tile_size = {.m = 256, .n = 256, .k = 32}}; @@ -28,12 +30,14 @@ TEST(FwdConvInstances, TEST(FwdConvInstances, Create_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_Instance_2D_BF16_NHWGC_Filter3x3) { - constexpr ConvSignature FwdConvSignature{ + constexpr ConvSignature FwdConvSignature{ .spatial_dim = 2, .direction = ConvDirection::FORWARD, .layout = GroupConvLayout2D::NHWGC_GKYXC_NHWGK, .data_type = DataType::BF16, - .elementwise_operation = ElementwiseOperation::PASS_THROUGH}; + .elementwise_operation = ElementwiseOperation::PASS_THROUGH, + .device_operation = FwdGroupConvDeviceOperation::DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 + }; constexpr ThreadBlock FwdThreadBlock{.block_size = 256, .tile_size = {.m = 256, .n = 256, .k = 32}}; diff --git a/experimental/builder/test/conv/test_ckb_conv_fwd_2d_fp16.cpp b/experimental/builder/test/conv/test_ckb_conv_fwd_2d_fp16.cpp index c81d7543bb..efd3ecc680 100644 --- a/experimental/builder/test/conv/test_ckb_conv_fwd_2d_fp16.cpp +++ 
b/experimental/builder/test/conv/test_ckb_conv_fwd_2d_fp16.cpp @@ -7,12 +7,14 @@ namespace ck_tile::builder::testing { TEST(FwdConvInstances, Create_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_Instance_2D_FP16_GNHWC) { - constexpr ConvSignature FwdConvSignature{ + constexpr ConvSignature FwdConvSignature{ .spatial_dim = 2, .direction = ConvDirection::FORWARD, .layout = GroupConvLayout2D::GNHWC_GKYXC_GNHWK, .data_type = DataType::FP16, - .elementwise_operation = ElementwiseOperation::PASS_THROUGH}; + .elementwise_operation = ElementwiseOperation::PASS_THROUGH, + .device_operation = FwdGroupConvDeviceOperation::DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 + }; constexpr ThreadBlock FwdThreadBlock{.block_size = 256, .tile_size = {.m = 256, .n = 256, .k = 32}}; diff --git a/experimental/builder/test/conv/test_ckb_conv_fwd_2d_fp32.cpp b/experimental/builder/test/conv/test_ckb_conv_fwd_2d_fp32.cpp index d55a120bb8..a7248d25b5 100644 --- a/experimental/builder/test/conv/test_ckb_conv_fwd_2d_fp32.cpp +++ b/experimental/builder/test/conv/test_ckb_conv_fwd_2d_fp32.cpp @@ -7,12 +7,14 @@ namespace ck_tile::builder::testing { TEST(FwdConvInstances, Create_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_Instance_2D_FP32_NGCHW_GKCYX) { - constexpr ConvSignature FwdConvSignature{ + constexpr ConvSignature FwdConvSignature{ .spatial_dim = 2, .direction = ConvDirection::FORWARD, .layout = GroupConvLayout2D::NGCHW_GKCYX_NGKHW, .data_type = DataType::FP32, - .elementwise_operation = ElementwiseOperation::PASS_THROUGH}; + .elementwise_operation = ElementwiseOperation::PASS_THROUGH, + .device_operation = FwdGroupConvDeviceOperation::DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 + }; constexpr ThreadBlock FwdThreadBlock{.block_size = 256, .tile_size = {.m = 128, .n = 128, .k = 32}}; diff --git a/experimental/builder/test/conv/test_ckb_conv_fwd_3d_bf16.cpp b/experimental/builder/test/conv/test_ckb_conv_fwd_3d_bf16.cpp index f7bcf49e54..b8c8bc7063 100644 --- 
a/experimental/builder/test/conv/test_ckb_conv_fwd_3d_bf16.cpp +++ b/experimental/builder/test/conv/test_ckb_conv_fwd_3d_bf16.cpp @@ -8,12 +8,14 @@ namespace ck_tile::builder::testing { TEST(FwdConvInstances, Create_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_Instance_3D_BF16_GNDHWC) { - constexpr ConvSignature FwdConvSignature{ + constexpr ConvSignature FwdConvSignature{ .spatial_dim = 3, .direction = ConvDirection::FORWARD, .layout = GroupConvLayout3D::GNDHWC_GKZYXC_GNDHWK, .data_type = DataType::BF16, - .elementwise_operation = ElementwiseOperation::PASS_THROUGH}; + .elementwise_operation = ElementwiseOperation::PASS_THROUGH, + .device_operation = FwdGroupConvDeviceOperation::DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 + }; constexpr ThreadBlock FwdThreadBlock{.block_size = 256, .tile_size = {.m = 256, .n = 256, .k = 32}}; diff --git a/experimental/builder/test/conv/test_ckb_conv_fwd_3d_fp16.cpp b/experimental/builder/test/conv/test_ckb_conv_fwd_3d_fp16.cpp index 27b5ddc821..035a9df36d 100644 --- a/experimental/builder/test/conv/test_ckb_conv_fwd_3d_fp16.cpp +++ b/experimental/builder/test/conv/test_ckb_conv_fwd_3d_fp16.cpp @@ -8,12 +8,14 @@ namespace ck_tile::builder::testing { TEST(FwdConvInstances, Create_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_Instance_3D_FP16_NDHWGC_ChannelsLast) { - constexpr ConvSignature FwdConvSignature{ + constexpr ConvSignature FwdConvSignature{ .spatial_dim = 3, .direction = ConvDirection::FORWARD, .layout = GroupConvLayout3D::NDHWGC_GKZYXC_NDHWGK, .data_type = DataType::FP16, - .elementwise_operation = ElementwiseOperation::PASS_THROUGH}; + .elementwise_operation = ElementwiseOperation::PASS_THROUGH, + .device_operation = FwdGroupConvDeviceOperation::DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 + }; constexpr ThreadBlock FwdThreadBlock{.block_size = 256, .tile_size = {.m = 128, .n = 128, .k = 32}}; diff --git a/experimental/builder/test/conv/test_ckb_conv_fwd_3d_fp32.cpp 
b/experimental/builder/test/conv/test_ckb_conv_fwd_3d_fp32.cpp index c0b6f04383..2713dd1b01 100644 --- a/experimental/builder/test/conv/test_ckb_conv_fwd_3d_fp32.cpp +++ b/experimental/builder/test/conv/test_ckb_conv_fwd_3d_fp32.cpp @@ -8,12 +8,14 @@ namespace ck_tile::builder::testing { TEST(FwdConvInstances, Create_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_Instance_3D_FP32_ChannelsFirst) { - constexpr ConvSignature FwdConvSignature{ + constexpr ConvSignature FwdConvSignature{ .spatial_dim = 3, .direction = ConvDirection::FORWARD, .layout = GroupConvLayout3D::NGCDHW_GKCZYX_NGKDHW, .data_type = DataType::FP32, - .elementwise_operation = ElementwiseOperation::PASS_THROUGH}; + .elementwise_operation = ElementwiseOperation::PASS_THROUGH, + .device_operation = FwdGroupConvDeviceOperation::DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 + }; constexpr ThreadBlock FwdThreadBlock{.block_size = 256, .tile_size = {.m = 128, .n = 128, .k = 32}}; diff --git a/experimental/builder/test/impl/conv_signature_types.hpp b/experimental/builder/test/impl/conv_signature_types.hpp index 297f827395..cc5490c711 100644 --- a/experimental/builder/test/impl/conv_signature_types.hpp +++ b/experimental/builder/test/impl/conv_signature_types.hpp @@ -3,11 +3,13 @@ #pragma once +#include #include "ck_tile/builder/conv_signature_concepts.hpp" namespace ck_tile::builder::test { -template +using namespace ck_tile::builder; + struct ConvSignature { int spatial_dim; @@ -15,9 +17,8 @@ struct ConvSignature GroupConvLayout layout; DataType data_type; ElementwiseOperation elementwise_operation; + GroupConvDeviceOp device_operation; }; -static_assert(ConvSignatureDescriptor>); -static_assert(ConvSignatureDescriptor>); -static_assert(ConvSignatureDescriptor>); +static_assert(ConvSignatureDescriptor); } // namespace ck_tile::builder::test diff --git a/experimental/builder/test/utils/ckb_conv_test_common.hpp b/experimental/builder/test/utils/ckb_conv_test_common.hpp index 7ad01bd922..cd3943d26f 100644 
--- a/experimental/builder/test/utils/ckb_conv_test_common.hpp +++ b/experimental/builder/test/utils/ckb_conv_test_common.hpp @@ -11,7 +11,7 @@ using namespace ck_tile::builder; using namespace test; // Common test implementation -template From 74ba32ea58bfa04478106e4430d63b34eddb94a5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ville=20Pietil=C3=A4?= Date: Wed, 29 Oct 2025 14:39:03 +0000 Subject: [PATCH 02/12] Add predicates for all device op instances. --- .../include/ck_tile/builder/conv_factory.hpp | 7 +- .../builder/conv_signature_concepts.hpp | 15 +- .../builder/conv_signature_predicates.hpp | 162 ++++++++++++++++++ .../builder/include/ck_tile/builder/types.hpp | 12 +- 4 files changed, 174 insertions(+), 22 deletions(-) create mode 100644 experimental/builder/include/ck_tile/builder/conv_signature_predicates.hpp diff --git a/experimental/builder/include/ck_tile/builder/conv_factory.hpp b/experimental/builder/include/ck_tile/builder/conv_factory.hpp index 252d423716..573577f2ee 100644 --- a/experimental/builder/include/ck_tile/builder/conv_factory.hpp +++ b/experimental/builder/include/ck_tile/builder/conv_factory.hpp @@ -454,15 +454,16 @@ template struct ConvFactory; -// Factory specialization for an instance of a grouped forward convolution kernel. +// Factory specialization for DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 instance +// of a grouped forward convolution kernel. 
template - requires ConvDirectionIsForward + requires ConvDirectionIsForward && + ConvDeviceOpIs_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 struct ConvFactory { static constexpr size_t SPATIAL_DIM = SIGNATURE.spatial_dim; - /*static constexpr auto*/ using Layouts = decltype(factory_internal::GetTensorLayout()); using Types = factory_internal::ConvTensorTypes; using Ops = factory_internal::ElementwiseOps; diff --git a/experimental/builder/include/ck_tile/builder/conv_signature_concepts.hpp b/experimental/builder/include/ck_tile/builder/conv_signature_concepts.hpp index 76e5590ad6..7864cde1ae 100644 --- a/experimental/builder/include/ck_tile/builder/conv_signature_concepts.hpp +++ b/experimental/builder/include/ck_tile/builder/conv_signature_concepts.hpp @@ -17,11 +17,11 @@ // signature at compile time. #pragma once -#include #include #include #include "ck_tile/builder/types.hpp" +#include "ck_tile/builder/conv_signature_predicates.hpp" namespace ck_tile::builder { @@ -63,18 +63,7 @@ template concept ValidConvSignature = requires { requires ConvSpatialDim; requires ConvDataType; + //requires ConvDeviceOp; }; -// Predicate for forward convolution. -template -concept ConvDirectionIsForward = (Sig.direction == ConvDirection::FORWARD); - -// Predicate for backward data convolution. -template -concept ConvDirectionIsBackwardData = (Sig.direction == ConvDirection::BACKWARD_DATA); - -// Predicate for backward weight convolution. -template -concept ConvDirectionIsBackwardWeight = (Sig.direction == ConvDirection::BACKWARD_WEIGHT); - } // namespace ck_tile::builder diff --git a/experimental/builder/include/ck_tile/builder/conv_signature_predicates.hpp b/experimental/builder/include/ck_tile/builder/conv_signature_predicates.hpp new file mode 100644 index 0000000000..9b47d87329 --- /dev/null +++ b/experimental/builder/include/ck_tile/builder/conv_signature_predicates.hpp @@ -0,0 +1,162 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. 
All rights reserved. + +#pragma once + + +#include +#include + +#include "ck_tile/builder/types.hpp" + +namespace ck_tile::builder { + +/********************************************** + * Conv Direction Predicates + **********************************************/ + +// Predicate for forward convolution. +template +concept ConvDirectionIsForward = (Sig.direction == ConvDirection::FORWARD); + +// Predicate for backward data convolution. +template +concept ConvDirectionIsBackwardData = (Sig.direction == ConvDirection::BACKWARD_DATA); + +// Predicate for backward weight convolution. +template +concept ConvDirectionIsBackwardWeight = (Sig.direction == ConvDirection::BACKWARD_WEIGHT); + +/********************************************** + * Conv Fwd Device Op Predicates + **********************************************/ + +// Predicate for DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 operation. +template +concept ConvDeviceOpIs_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 = + (Sig.device_operation._fwd == FwdGroupConvDeviceOperation::DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3); + +// Predicate for DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK operation. +template +concept ConvDeviceOpIs_DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK = + (Sig.device_operation._fwd == FwdGroupConvDeviceOperation::DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK); + +// Predicate for DeviceGroupedConvFwdMultipleD_Wmma_CShuffle operation. +template +concept ConvDeviceOpIs_DeviceGroupedConvFwdMultipleD_Wmma_CShuffle = + (Sig.device_operation._fwd == FwdGroupConvDeviceOperation::DeviceGroupedConvFwdMultipleD_Wmma_CShuffle); + +// Predicate for DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle operation. +template +concept ConvDeviceOpIs_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle = + (Sig.device_operation._fwd == FwdGroupConvDeviceOperation::DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle); + +// Predicate for DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor operation. 
+template +concept ConvDeviceOpIs_DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor = + (Sig.device_operation._fwd == FwdGroupConvDeviceOperation::DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor); + +// Generic predicate to check if signature uses any forward convolution device operation. +template +concept ConvDeviceOpIsForward = + ConvDeviceOpIs_DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK || + ConvDeviceOpIs_DeviceGroupedConvFwdMultipleD_Wmma_CShuffle || + ConvDeviceOpIs_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle || + ConvDeviceOpIs_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 || + ConvDeviceOpIs_DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor; + + +/********************************************** + * Conv Bwd Weight Device Op Predicates + **********************************************/ + +// Predicate for DeviceGroupedConvBwdWeight operation. +template +concept ConvDeviceOpIs_DeviceGroupedConvBwdWeight = + (Sig.device_operation._bwd_weight == BwdWeightGroupConvDeviceOperation::DeviceGroupedConvBwdWeight); + +// Predicate for DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle operation. +template +concept ConvDeviceOpIs_DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle = + (Sig.device_operation._bwd_weight == BwdWeightGroupConvDeviceOperation::DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle); + +// Predicate for DeviceGroupedConvBwdWeight_Xdl_CShuffle operation. +template +concept ConvDeviceOpIs_DeviceGroupedConvBwdWeight_Xdl_CShuffle = + (Sig.device_operation._bwd_weight == BwdWeightGroupConvDeviceOperation::DeviceGroupedConvBwdWeight_Xdl_CShuffle); + +// Predicate for DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle operation. +template +concept ConvDeviceOpIs_DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle = + (Sig.device_operation._bwd_weight == BwdWeightGroupConvDeviceOperation::DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle); + +// Predicate for DeviceGroupedConvBwdWeight_Wmma_CShuffle operation. 
+template +concept ConvDeviceOpIs_DeviceGroupedConvBwdWeight_Wmma_CShuffle = + (Sig.device_operation._bwd_weight == BwdWeightGroupConvDeviceOperation::DeviceGroupedConvBwdWeight_Wmma_CShuffle); + +// Predicate for DeviceGroupedConvBwdWeight_Xdl_CShuffleV3 operation. +template +concept ConvDeviceOpIs_DeviceGroupedConvBwdWeight_Xdl_CShuffleV3 = + (Sig.device_operation._bwd_weight == BwdWeightGroupConvDeviceOperation::DeviceGroupedConvBwdWeight_Xdl_CShuffleV3); + +// Predicate for DeviceGroupedConvBwdWeightMultipleD operation. +template +concept ConvDeviceOpIs_DeviceGroupedConvBwdWeightMultipleD = + (Sig.device_operation._bwd_weight == BwdWeightGroupConvDeviceOperation::DeviceGroupedConvBwdWeightMultipleD); + +// Predicate for DeviceGroupedConvBwdWeight_Dl operation. +template +concept ConvDeviceOpIs_DeviceGroupedConvBwdWeight_Dl = + (Sig.device_operation._bwd_weight == BwdWeightGroupConvDeviceOperation::DeviceGroupedConvBwdWeight_Dl); + +// Generic predicate to check if signature uses any backward weight convolution device operation. +template +concept ConvDeviceOpIsBackwardWeight = + ConvDeviceOpIs_DeviceGroupedConvBwdWeight || + ConvDeviceOpIs_DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle || + ConvDeviceOpIs_DeviceGroupedConvBwdWeight_Xdl_CShuffle || + ConvDeviceOpIs_DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle || + ConvDeviceOpIs_DeviceGroupedConvBwdWeight_Wmma_CShuffle || + ConvDeviceOpIs_DeviceGroupedConvBwdWeight_Xdl_CShuffleV3 || + ConvDeviceOpIs_DeviceGroupedConvBwdWeightMultipleD || + ConvDeviceOpIs_DeviceGroupedConvBwdWeight_Dl; + +/********************************************** + * Conv Bwd Data Device Op Predicates + **********************************************/ + +// Predicate for DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1 operation. 
+template +concept ConvDeviceOpIs_DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1 = + (Sig.device_operation._bwd_data == BwdDataGroupConvDeviceOperation::DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1); + +// Predicate for DeviceGroupedConvBwdDataMultipleD operation. +template +concept ConvDeviceOpIs_DeviceGroupedConvBwdDataMultipleD = + (Sig.device_operation._bwd_data == BwdDataGroupConvDeviceOperation::DeviceGroupedConvBwdDataMultipleD); + +// Predicate for DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle operation. +template +concept ConvDeviceOpIs_DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle = + (Sig.device_operation._bwd_data == BwdDataGroupConvDeviceOperation::DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle); + +// Generic predicate to check if signature uses any backward data convolution device operation. +template +concept ConvDeviceOpIsBackwardData = + ConvDeviceOpIs_DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1 || + ConvDeviceOpIs_DeviceGroupedConvBwdDataMultipleD || + ConvDeviceOpIs_DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle; + +/********************************************** + * Generic Device Op Predicates + **********************************************/ + +// Generic predicate to check if signature uses any device operation. +template +concept IsValidConvDeviceOp = + ConvDeviceOpIsForward || + ConvDeviceOpIsBackwardData || + ConvDeviceOpIsBackwardWeight; + +} // namespace ck_tile::builder diff --git a/experimental/builder/include/ck_tile/builder/types.hpp b/experimental/builder/include/ck_tile/builder/types.hpp index 509f240edd..7c0e23abde 100644 --- a/experimental/builder/include/ck_tile/builder/types.hpp +++ b/experimental/builder/include/ck_tile/builder/types.hpp @@ -81,22 +81,22 @@ enum class FwdGroupConvDeviceOperation // Backward data convolution device operations. 
enum class BwdDataGroupConvDeviceOperation { - DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1, DeviceGroupedConvBwdDataMultipleD, - DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle + DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle, + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1 }; // Backward weight convolution device operations. enum class BwdWeightGroupConvDeviceOperation { DeviceGroupedConvBwdWeight, - DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle, + DeviceGroupedConvBwdWeight_Dl, DeviceGroupedConvBwdWeight_Xdl_CShuffle, - DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle, - DeviceGroupedConvBwdWeight_Wmma_CShuffle, DeviceGroupedConvBwdWeight_Xdl_CShuffleV3, + DeviceGroupedConvBwdWeight_Wmma_CShuffle, + DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle, DeviceGroupedConvBwdWeightMultipleD, - DeviceGroupedConvBwdWeight_Dl + DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle, }; // Structural type for device operation From fbdded692701dba337150e919e8a48cd6f11379b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ville=20Pietil=C3=A4?= Date: Wed, 29 Oct 2025 14:44:29 +0000 Subject: [PATCH 03/12] Use the device op signature for validation. 
--- .../builder/include/ck_tile/builder/conv_signature_concepts.hpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/experimental/builder/include/ck_tile/builder/conv_signature_concepts.hpp b/experimental/builder/include/ck_tile/builder/conv_signature_concepts.hpp index 7864cde1ae..370e7b6521 100644 --- a/experimental/builder/include/ck_tile/builder/conv_signature_concepts.hpp +++ b/experimental/builder/include/ck_tile/builder/conv_signature_concepts.hpp @@ -63,7 +63,7 @@ template concept ValidConvSignature = requires { requires ConvSpatialDim; requires ConvDataType; - //requires ConvDeviceOp; + requires IsValidConvDeviceOp; }; } // namespace ck_tile::builder From ee1398250d17aab531f3c252fc74a0f1f8c4873f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ville=20Pietil=C3=A4?= Date: Thu, 30 Oct 2025 07:09:30 +0000 Subject: [PATCH 04/12] Fix ckb CMakeLists.txt file for tests. --- experimental/builder/test/CMakeLists.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/experimental/builder/test/CMakeLists.txt b/experimental/builder/test/CMakeLists.txt index b7adbc116a..c53ce6210a 100644 --- a/experimental/builder/test/CMakeLists.txt +++ b/experimental/builder/test/CMakeLists.txt @@ -19,7 +19,7 @@ endfunction() # The test_conv_builder target has all the unit tests (each test should run < 10 ms) add_ck_builder_test(test_conv_builder test_conv_builder.cpp - test_instance_traits.cpp + test_fwd_instance_traits.cpp test_instance_traits_util.cpp) add_ck_builder_test(test_inline_diff test_inline_diff.cpp) From 28e0d5f82f9d4a00c4c7fcc4d4678836d0a4e75a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ville=20Pietil=C3=A4?= Date: Thu, 30 Oct 2025 07:55:11 +0000 Subject: [PATCH 05/12] Fix building CK Builder instance traits after the introduction of direct load template parameter in CK. 
--- ...e_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/experimental/builder/include/ck_tile/builder/reflect/instance_traits_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp b/experimental/builder/include/ck_tile/builder/reflect/instance_traits_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp index 21201b8d50..9ab827e3a5 100644 --- a/experimental/builder/include/ck_tile/builder/reflect/instance_traits_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp +++ b/experimental/builder/include/ck_tile/builder/reflect/instance_traits_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp @@ -69,7 +69,8 @@ template + typename BComputeDataType, + bool DirectLoad> struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3; } // namespace ck::tensor_operation::device @@ -124,7 +125,8 @@ template + typename BComputeDataType_, + bool DirectLoad> struct InstanceTraits> + BComputeDataType_, + DirectLoad>> { // Spatial dimension static constexpr int kSpatialDim = NDimSpatial; @@ -336,6 +339,7 @@ struct InstanceTraits(); // 47. AComputeDataType oss << "," << detail::type_name(); // 48. BComputeDataType + oss << "," << (DirectLoad ? "true" : "false"); // 49. DirectLoad oss << ">"; return oss.str(); From c8eac6f094ad270aef234d174475d9c4e8ac3a91 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ville=20Pietil=C3=A4?= Date: Thu, 30 Oct 2025 09:01:11 +0000 Subject: [PATCH 06/12] Fix clang-formatting. 
--- .../include/ck_tile/builder/conv_factory.hpp | 8 ++- .../builder/conv_signature_predicates.hpp | 62 +++++++++++-------- .../builder/include/ck_tile/builder/types.hpp | 16 +++-- .../test/conv/test_ckb_conv_fwd_1d_bf16.cpp | 4 +- .../test/conv/test_ckb_conv_fwd_2d_bf16.cpp | 8 +-- .../test/conv/test_ckb_conv_fwd_2d_fp16.cpp | 4 +- .../test/conv/test_ckb_conv_fwd_2d_fp32.cpp | 4 +- .../test/conv/test_ckb_conv_fwd_3d_bf16.cpp | 4 +- .../test/conv/test_ckb_conv_fwd_3d_fp16.cpp | 4 +- .../test/conv/test_ckb_conv_fwd_3d_fp32.cpp | 4 +- 10 files changed, 68 insertions(+), 50 deletions(-) diff --git a/experimental/builder/include/ck_tile/builder/conv_factory.hpp b/experimental/builder/include/ck_tile/builder/conv_factory.hpp index 573577f2ee..31be8c322c 100644 --- a/experimental/builder/include/ck_tile/builder/conv_factory.hpp +++ b/experimental/builder/include/ck_tile/builder/conv_factory.hpp @@ -454,17 +454,19 @@ template struct ConvFactory; -// Factory specialization for DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 instance +// Factory specialization for DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 instance // of a grouped forward convolution kernel. 
template - requires ConvDirectionIsForward && + requires ConvDirectionIsForward && ConvDeviceOpIs_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 struct ConvFactory { static constexpr size_t SPATIAL_DIM = SIGNATURE.spatial_dim; - using Layouts = decltype(factory_internal::GetTensorLayout()); + using Layouts = decltype(factory_internal::GetTensorLayout()); using Types = factory_internal::ConvTensorTypes; using Ops = factory_internal::ElementwiseOps; using AlgorithmType = decltype(ALGORITHM); diff --git a/experimental/builder/include/ck_tile/builder/conv_signature_predicates.hpp b/experimental/builder/include/ck_tile/builder/conv_signature_predicates.hpp index 9b47d87329..f947c7e329 100644 --- a/experimental/builder/include/ck_tile/builder/conv_signature_predicates.hpp +++ b/experimental/builder/include/ck_tile/builder/conv_signature_predicates.hpp @@ -3,7 +3,6 @@ #pragma once - #include #include @@ -34,38 +33,42 @@ concept ConvDirectionIsBackwardWeight = (Sig.direction == ConvDirection::BACKWAR // Predicate for DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 operation. template concept ConvDeviceOpIs_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 = - (Sig.device_operation._fwd == FwdGroupConvDeviceOperation::DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3); + (Sig.device_operation._fwd == + FwdGroupConvDeviceOperation::DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3); // Predicate for DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK operation. template concept ConvDeviceOpIs_DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK = - (Sig.device_operation._fwd == FwdGroupConvDeviceOperation::DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK); + (Sig.device_operation._fwd == + FwdGroupConvDeviceOperation::DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK); // Predicate for DeviceGroupedConvFwdMultipleD_Wmma_CShuffle operation. 
template concept ConvDeviceOpIs_DeviceGroupedConvFwdMultipleD_Wmma_CShuffle = - (Sig.device_operation._fwd == FwdGroupConvDeviceOperation::DeviceGroupedConvFwdMultipleD_Wmma_CShuffle); + (Sig.device_operation._fwd == + FwdGroupConvDeviceOperation::DeviceGroupedConvFwdMultipleD_Wmma_CShuffle); // Predicate for DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle operation. template concept ConvDeviceOpIs_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle = - (Sig.device_operation._fwd == FwdGroupConvDeviceOperation::DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle); + (Sig.device_operation._fwd == + FwdGroupConvDeviceOperation::DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle); // Predicate for DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor operation. template concept ConvDeviceOpIs_DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor = - (Sig.device_operation._fwd == FwdGroupConvDeviceOperation::DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor); + (Sig.device_operation._fwd == + FwdGroupConvDeviceOperation::DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor); // Generic predicate to check if signature uses any forward convolution device operation. template -concept ConvDeviceOpIsForward = +concept ConvDeviceOpIsForward = ConvDeviceOpIs_DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK || ConvDeviceOpIs_DeviceGroupedConvFwdMultipleD_Wmma_CShuffle || ConvDeviceOpIs_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle || ConvDeviceOpIs_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 || ConvDeviceOpIs_DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor; - /********************************************** * Conv Bwd Weight Device Op Predicates **********************************************/ @@ -73,46 +76,54 @@ concept ConvDeviceOpIsForward = // Predicate for DeviceGroupedConvBwdWeight operation. 
template concept ConvDeviceOpIs_DeviceGroupedConvBwdWeight = - (Sig.device_operation._bwd_weight == BwdWeightGroupConvDeviceOperation::DeviceGroupedConvBwdWeight); + (Sig.device_operation._bwd_weight == + BwdWeightGroupConvDeviceOperation::DeviceGroupedConvBwdWeight); // Predicate for DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle operation. template concept ConvDeviceOpIs_DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle = - (Sig.device_operation._bwd_weight == BwdWeightGroupConvDeviceOperation::DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle); + (Sig.device_operation._bwd_weight == + BwdWeightGroupConvDeviceOperation::DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle); // Predicate for DeviceGroupedConvBwdWeight_Xdl_CShuffle operation. template concept ConvDeviceOpIs_DeviceGroupedConvBwdWeight_Xdl_CShuffle = - (Sig.device_operation._bwd_weight == BwdWeightGroupConvDeviceOperation::DeviceGroupedConvBwdWeight_Xdl_CShuffle); + (Sig.device_operation._bwd_weight == + BwdWeightGroupConvDeviceOperation::DeviceGroupedConvBwdWeight_Xdl_CShuffle); // Predicate for DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle operation. template concept ConvDeviceOpIs_DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle = - (Sig.device_operation._bwd_weight == BwdWeightGroupConvDeviceOperation::DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle); + (Sig.device_operation._bwd_weight == + BwdWeightGroupConvDeviceOperation::DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle); // Predicate for DeviceGroupedConvBwdWeight_Wmma_CShuffle operation. template concept ConvDeviceOpIs_DeviceGroupedConvBwdWeight_Wmma_CShuffle = - (Sig.device_operation._bwd_weight == BwdWeightGroupConvDeviceOperation::DeviceGroupedConvBwdWeight_Wmma_CShuffle); + (Sig.device_operation._bwd_weight == + BwdWeightGroupConvDeviceOperation::DeviceGroupedConvBwdWeight_Wmma_CShuffle); // Predicate for DeviceGroupedConvBwdWeight_Xdl_CShuffleV3 operation. 
template concept ConvDeviceOpIs_DeviceGroupedConvBwdWeight_Xdl_CShuffleV3 = - (Sig.device_operation._bwd_weight == BwdWeightGroupConvDeviceOperation::DeviceGroupedConvBwdWeight_Xdl_CShuffleV3); + (Sig.device_operation._bwd_weight == + BwdWeightGroupConvDeviceOperation::DeviceGroupedConvBwdWeight_Xdl_CShuffleV3); // Predicate for DeviceGroupedConvBwdWeightMultipleD operation. template concept ConvDeviceOpIs_DeviceGroupedConvBwdWeightMultipleD = - (Sig.device_operation._bwd_weight == BwdWeightGroupConvDeviceOperation::DeviceGroupedConvBwdWeightMultipleD); + (Sig.device_operation._bwd_weight == + BwdWeightGroupConvDeviceOperation::DeviceGroupedConvBwdWeightMultipleD); // Predicate for DeviceGroupedConvBwdWeight_Dl operation. template concept ConvDeviceOpIs_DeviceGroupedConvBwdWeight_Dl = - (Sig.device_operation._bwd_weight == BwdWeightGroupConvDeviceOperation::DeviceGroupedConvBwdWeight_Dl); + (Sig.device_operation._bwd_weight == + BwdWeightGroupConvDeviceOperation::DeviceGroupedConvBwdWeight_Dl); // Generic predicate to check if signature uses any backward weight convolution device operation. template -concept ConvDeviceOpIsBackwardWeight = +concept ConvDeviceOpIsBackwardWeight = ConvDeviceOpIs_DeviceGroupedConvBwdWeight || ConvDeviceOpIs_DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle || ConvDeviceOpIs_DeviceGroupedConvBwdWeight_Xdl_CShuffle || @@ -129,21 +140,24 @@ concept ConvDeviceOpIsBackwardWeight = // Predicate for DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1 operation. template concept ConvDeviceOpIs_DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1 = - (Sig.device_operation._bwd_data == BwdDataGroupConvDeviceOperation::DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1); + (Sig.device_operation._bwd_data == + BwdDataGroupConvDeviceOperation::DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1); // Predicate for DeviceGroupedConvBwdDataMultipleD operation. 
template concept ConvDeviceOpIs_DeviceGroupedConvBwdDataMultipleD = - (Sig.device_operation._bwd_data == BwdDataGroupConvDeviceOperation::DeviceGroupedConvBwdDataMultipleD); + (Sig.device_operation._bwd_data == + BwdDataGroupConvDeviceOperation::DeviceGroupedConvBwdDataMultipleD); // Predicate for DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle operation. template concept ConvDeviceOpIs_DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle = - (Sig.device_operation._bwd_data == BwdDataGroupConvDeviceOperation::DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle); + (Sig.device_operation._bwd_data == + BwdDataGroupConvDeviceOperation::DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle); // Generic predicate to check if signature uses any backward data convolution device operation. template -concept ConvDeviceOpIsBackwardData = +concept ConvDeviceOpIsBackwardData = ConvDeviceOpIs_DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1 || ConvDeviceOpIs_DeviceGroupedConvBwdDataMultipleD || ConvDeviceOpIs_DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle; @@ -154,9 +168,7 @@ concept ConvDeviceOpIsBackwardData = // Generic predicate to check if signature uses any device operation. 
template -concept IsValidConvDeviceOp = - ConvDeviceOpIsForward || - ConvDeviceOpIsBackwardData || - ConvDeviceOpIsBackwardWeight; +concept IsValidConvDeviceOp = ConvDeviceOpIsForward || ConvDeviceOpIsBackwardData || + ConvDeviceOpIsBackwardWeight; } // namespace ck_tile::builder diff --git a/experimental/builder/include/ck_tile/builder/types.hpp b/experimental/builder/include/ck_tile/builder/types.hpp index 7c0e23abde..47bd8327d4 100644 --- a/experimental/builder/include/ck_tile/builder/types.hpp +++ b/experimental/builder/include/ck_tile/builder/types.hpp @@ -48,13 +48,15 @@ enum class GroupConvLayout3D NGCDHW_GKCZYX_NGKDHW, }; -struct GroupConvLayout { - union { +struct GroupConvLayout +{ + union + { GroupConvLayout1D _1d; GroupConvLayout2D _2d; GroupConvLayout3D _3d; }; - + constexpr GroupConvLayout(GroupConvLayout1D layout) : _1d(layout) {} constexpr GroupConvLayout(GroupConvLayout2D layout) : _2d(layout) {} constexpr GroupConvLayout(GroupConvLayout3D layout) : _3d(layout) {} @@ -100,13 +102,15 @@ enum class BwdWeightGroupConvDeviceOperation }; // Structural type for device operation -struct GroupConvDeviceOp { - union { +struct GroupConvDeviceOp +{ + union + { FwdGroupConvDeviceOperation _fwd; BwdDataGroupConvDeviceOperation _bwd_data; BwdWeightGroupConvDeviceOperation _bwd_weight; }; - + constexpr GroupConvDeviceOp(FwdGroupConvDeviceOperation op) : _fwd(op) {} constexpr GroupConvDeviceOp(BwdDataGroupConvDeviceOperation op) : _bwd_data(op) {} constexpr GroupConvDeviceOp(BwdWeightGroupConvDeviceOperation op) : _bwd_weight(op) {} diff --git a/experimental/builder/test/conv/test_ckb_conv_fwd_1d_bf16.cpp b/experimental/builder/test/conv/test_ckb_conv_fwd_1d_bf16.cpp index b660fa3303..77ff0fe28f 100644 --- a/experimental/builder/test/conv/test_ckb_conv_fwd_1d_bf16.cpp +++ b/experimental/builder/test/conv/test_ckb_conv_fwd_1d_bf16.cpp @@ -15,8 +15,8 @@ TEST(FwdConvInstances, .layout = GroupConvLayout1D::NGCW_GKXC_NGKW, .data_type = DataType::BF16, 
.elementwise_operation = ElementwiseOperation::SCALE, - .device_operation = FwdGroupConvDeviceOperation::DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 - }; + .device_operation = + FwdGroupConvDeviceOperation::DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3}; constexpr ThreadBlock FwdThreadBlock{.block_size = 256, .tile_size = {.m = 256, .n = 256, .k = 32}}; diff --git a/experimental/builder/test/conv/test_ckb_conv_fwd_2d_bf16.cpp b/experimental/builder/test/conv/test_ckb_conv_fwd_2d_bf16.cpp index cf942f56a1..5be7d5e604 100644 --- a/experimental/builder/test/conv/test_ckb_conv_fwd_2d_bf16.cpp +++ b/experimental/builder/test/conv/test_ckb_conv_fwd_2d_bf16.cpp @@ -14,8 +14,8 @@ TEST(FwdConvInstances, .layout = GroupConvLayout2D::NHWGC_GKYXC_NHWGK, .data_type = DataType::BF16, .elementwise_operation = ElementwiseOperation::PASS_THROUGH, - .device_operation = FwdGroupConvDeviceOperation::DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 - }; + .device_operation = + FwdGroupConvDeviceOperation::DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3}; constexpr ThreadBlock FwdThreadBlock{.block_size = 256, .tile_size = {.m = 256, .n = 256, .k = 32}}; @@ -36,8 +36,8 @@ TEST(FwdConvInstances, .layout = GroupConvLayout2D::NHWGC_GKYXC_NHWGK, .data_type = DataType::BF16, .elementwise_operation = ElementwiseOperation::PASS_THROUGH, - .device_operation = FwdGroupConvDeviceOperation::DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 - }; + .device_operation = + FwdGroupConvDeviceOperation::DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3}; constexpr ThreadBlock FwdThreadBlock{.block_size = 256, .tile_size = {.m = 256, .n = 256, .k = 32}}; diff --git a/experimental/builder/test/conv/test_ckb_conv_fwd_2d_fp16.cpp b/experimental/builder/test/conv/test_ckb_conv_fwd_2d_fp16.cpp index efd3ecc680..4abe3df40d 100644 --- a/experimental/builder/test/conv/test_ckb_conv_fwd_2d_fp16.cpp +++ b/experimental/builder/test/conv/test_ckb_conv_fwd_2d_fp16.cpp @@ -13,8 +13,8 @@ TEST(FwdConvInstances, .layout 
= GroupConvLayout2D::GNHWC_GKYXC_GNHWK, .data_type = DataType::FP16, .elementwise_operation = ElementwiseOperation::PASS_THROUGH, - .device_operation = FwdGroupConvDeviceOperation::DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 - }; + .device_operation = + FwdGroupConvDeviceOperation::DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3}; constexpr ThreadBlock FwdThreadBlock{.block_size = 256, .tile_size = {.m = 256, .n = 256, .k = 32}}; diff --git a/experimental/builder/test/conv/test_ckb_conv_fwd_2d_fp32.cpp b/experimental/builder/test/conv/test_ckb_conv_fwd_2d_fp32.cpp index a7248d25b5..5ea804cf8b 100644 --- a/experimental/builder/test/conv/test_ckb_conv_fwd_2d_fp32.cpp +++ b/experimental/builder/test/conv/test_ckb_conv_fwd_2d_fp32.cpp @@ -13,8 +13,8 @@ TEST(FwdConvInstances, .layout = GroupConvLayout2D::NGCHW_GKCYX_NGKHW, .data_type = DataType::FP32, .elementwise_operation = ElementwiseOperation::PASS_THROUGH, - .device_operation = FwdGroupConvDeviceOperation::DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 - }; + .device_operation = + FwdGroupConvDeviceOperation::DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3}; constexpr ThreadBlock FwdThreadBlock{.block_size = 256, .tile_size = {.m = 128, .n = 128, .k = 32}}; diff --git a/experimental/builder/test/conv/test_ckb_conv_fwd_3d_bf16.cpp b/experimental/builder/test/conv/test_ckb_conv_fwd_3d_bf16.cpp index b8c8bc7063..c729148346 100644 --- a/experimental/builder/test/conv/test_ckb_conv_fwd_3d_bf16.cpp +++ b/experimental/builder/test/conv/test_ckb_conv_fwd_3d_bf16.cpp @@ -14,8 +14,8 @@ TEST(FwdConvInstances, .layout = GroupConvLayout3D::GNDHWC_GKZYXC_GNDHWK, .data_type = DataType::BF16, .elementwise_operation = ElementwiseOperation::PASS_THROUGH, - .device_operation = FwdGroupConvDeviceOperation::DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 - }; + .device_operation = + FwdGroupConvDeviceOperation::DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3}; constexpr ThreadBlock FwdThreadBlock{.block_size = 256, .tile_size = 
{.m = 256, .n = 256, .k = 32}}; diff --git a/experimental/builder/test/conv/test_ckb_conv_fwd_3d_fp16.cpp b/experimental/builder/test/conv/test_ckb_conv_fwd_3d_fp16.cpp index 035a9df36d..832acd7412 100644 --- a/experimental/builder/test/conv/test_ckb_conv_fwd_3d_fp16.cpp +++ b/experimental/builder/test/conv/test_ckb_conv_fwd_3d_fp16.cpp @@ -14,8 +14,8 @@ TEST(FwdConvInstances, .layout = GroupConvLayout3D::NDHWGC_GKZYXC_NDHWGK, .data_type = DataType::FP16, .elementwise_operation = ElementwiseOperation::PASS_THROUGH, - .device_operation = FwdGroupConvDeviceOperation::DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 - }; + .device_operation = + FwdGroupConvDeviceOperation::DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3}; constexpr ThreadBlock FwdThreadBlock{.block_size = 256, .tile_size = {.m = 128, .n = 128, .k = 32}}; diff --git a/experimental/builder/test/conv/test_ckb_conv_fwd_3d_fp32.cpp b/experimental/builder/test/conv/test_ckb_conv_fwd_3d_fp32.cpp index 2713dd1b01..9d0e107dbc 100644 --- a/experimental/builder/test/conv/test_ckb_conv_fwd_3d_fp32.cpp +++ b/experimental/builder/test/conv/test_ckb_conv_fwd_3d_fp32.cpp @@ -14,8 +14,8 @@ TEST(FwdConvInstances, .layout = GroupConvLayout3D::NGCDHW_GKCZYX_NGKDHW, .data_type = DataType::FP32, .elementwise_operation = ElementwiseOperation::PASS_THROUGH, - .device_operation = FwdGroupConvDeviceOperation::DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 - }; + .device_operation = + FwdGroupConvDeviceOperation::DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3}; constexpr ThreadBlock FwdThreadBlock{.block_size = 256, .tile_size = {.m = 128, .n = 128, .k = 32}}; From 0b83a5745441f6b04062ed5d01bf61682737072b Mon Sep 17 00:00:00 2001 From: JH-Leon-KIM-AMD Date: Tue, 4 Nov 2025 09:23:28 +0000 Subject: [PATCH 07/12] add device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk --- .../include/ck_tile/builder/conv_factory.hpp | 114 ++++++++++++++++++ 1 file changed, 114 insertions(+) diff --git 
a/experimental/builder/include/ck_tile/builder/conv_factory.hpp b/experimental/builder/include/ck_tile/builder/conv_factory.hpp index 8ea3e18d65..167b569ee5 100644 --- a/experimental/builder/include/ck_tile/builder/conv_factory.hpp +++ b/experimental/builder/include/ck_tile/builder/conv_factory.hpp @@ -36,6 +36,7 @@ #pragma once +#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp" #include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_wmma_cshuffle.hpp" #include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp" #include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp" @@ -990,4 +991,117 @@ struct ConvFactory GRIDWISE_GEMM_PIPELINE_VERSION>; }; +// Factory specialization for DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK instance +// of a grouped forward convolution kernel using Direct Load (DL) approach. +template + requires ConvDirectionIsForward && + ConvDeviceOpIs_DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK +struct ConvFactory +{ + static constexpr size_t SPATIAL_DIM = SIGNATURE.spatial_dim; + using Layouts = decltype(factory_internal::GetTensorLayout()); + using Types = factory_internal::ConvTensorTypes; + using Ops = factory_internal::ElementwiseOps; + using AlgorithmType = decltype(ALGORITHM); + + static_assert(SpecifiesThreadBlock, + "The convolution algorithm descriptor must specify thread block info."); + static_assert(SpecifiesFwdConcSpecialization, + "The convolution algorithm descriptor must specify forward convolution " + "specialization."); + static_assert(SpecifiesGemmSpecialization, + "The convolution algorithm descriptor must specify gemm specialization."); + + static constexpr auto FWD_CONV_SPECIALIZATION = + factory_internal::SetFwdConvSpecialization(); + static constexpr auto GEMM_SPECIALIZATION = + factory_internal::SetGemmSpecialization(); + + static constexpr auto BLOCK = 
factory_internal::SetThreadBlockInfo(); + + // DL-specific hardcoded parameters based on the example from convnd_fwd_dl_fp16.cpp + // These can be made configurable in the future + static constexpr ck::index_t K0PerBlock = 16; + static constexpr ck::index_t K1 = 2; + static constexpr ck::index_t M1PerThread = 4; + static constexpr ck::index_t N1PerThread = 4; + static constexpr ck::index_t KPerThread = 1; + + // Thread cluster configuration + using M1N1ThreadClusterM1Xs = ck::Sequence<8, 2>; + using M1N1ThreadClusterN1Xs = ck::Sequence<8, 2>; + + // A Block Transfer - K0_M0_M1_K1 tensor format + using ABlockTransferThreadSliceLengths_K0_M0_M1_K1 = ck::Sequence<8, 1, 1, 2>; + using ABlockTransferThreadClusterLengths_K0_M0_M1_K1 = ck::Sequence<2, 1, 128, 1>; + using ABlockTransferThreadClusterArrangeOrder = ck::Sequence<1, 2, 0, 3>; + using ABlockTransferSrcAccessOrder = ck::Sequence<1, 2, 0, 3>; + using ABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1 = ck::Sequence<4, 1, 1, 2>; + using ABlockTransferSrcVectorTensorContiguousDimOrder = ck::Sequence<1, 2, 0, 3>; + using ABlockTransferDstVectorTensorLengths_K0_M0_M1_K1 = ck::Sequence<1, 1, 1, 2>; + + // B Block Transfer - K0_N0_N1_K1 tensor format + using BBlockTransferThreadSliceLengths_K0_N0_N1_K1 = ck::Sequence<8, 1, 1, 2>; + using BBlockTransferThreadClusterLengths_K0_N0_N1_K1 = ck::Sequence<2, 1, 128, 1>; + using BBlockTransferThreadClusterArrangeOrder = ck::Sequence<1, 2, 0, 3>; + using BBlockTransferSrcAccessOrder = ck::Sequence<1, 2, 0, 3>; + using BBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1 = ck::Sequence<4, 1, 1, 2>; + using BBlockTransferSrcVectorTensorContiguousDimOrder = ck::Sequence<1, 2, 0, 3>; + using BBlockTransferDstVectorTensorLengths_K0_N0_N1_K1 = ck::Sequence<1, 1, 1, 2>; + + // C Thread Transfer + using CThreadTransferSrcDstAccessOrder = ck::Sequence<0, 1, 2, 3, 4, 5>; + static constexpr ck::index_t CThreadTransferSrcDstVectorDim = 5; + static constexpr ck::index_t 
CThreadTransferDstScalarPerVector = 4; + + // The DL forward convolution kernel class instance + using Instance = ck::tensor_operation::device::DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK< + SPATIAL_DIM, + typename Types::ADataType, + typename Types::BDataType, + typename Types::DsDataTypes, + typename Types::EDataType, + typename Types::AccDataType, + typename Layouts::ALayout, + typename Layouts::BLayout, + typename Layouts::DsLayout, + typename Layouts::ELayout, + typename Ops::AElementwiseOp, + typename Ops::BElementwiseOp, + typename Ops::CDEElementwiseOp, + FWD_CONV_SPECIALIZATION, + GEMM_SPECIALIZATION, + BLOCK.block_size, + BLOCK.per_block.m, + BLOCK.per_block.n, + K0PerBlock, + K1, + M1PerThread, + N1PerThread, + KPerThread, + M1N1ThreadClusterM1Xs, + M1N1ThreadClusterN1Xs, + ABlockTransferThreadSliceLengths_K0_M0_M1_K1, + ABlockTransferThreadClusterLengths_K0_M0_M1_K1, + ABlockTransferThreadClusterArrangeOrder, + ABlockTransferSrcAccessOrder, + ABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1, + ABlockTransferSrcVectorTensorContiguousDimOrder, + ABlockTransferDstVectorTensorLengths_K0_M0_M1_K1, + BBlockTransferThreadSliceLengths_K0_N0_N1_K1, + BBlockTransferThreadClusterLengths_K0_N0_N1_K1, + BBlockTransferThreadClusterArrangeOrder, + BBlockTransferSrcAccessOrder, + BBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1, + BBlockTransferSrcVectorTensorContiguousDimOrder, + BBlockTransferDstVectorTensorLengths_K0_N0_N1_K1, + CThreadTransferSrcDstAccessOrder, + CThreadTransferSrcDstVectorDim, + CThreadTransferDstScalarPerVector>; +}; + } // namespace ck_tile::builder From f207c900d0a480b82ca2a1a731614d333cfd0832 Mon Sep 17 00:00:00 2001 From: JH-Leon-KIM-AMD Date: Tue, 4 Nov 2025 14:57:12 +0000 Subject: [PATCH 08/12] Add full DL configurability with Option A implementation - Added 5 DL descriptor structs (39 configurable parameters) - Added 10 C++20 concepts for type-safe validation - Updated factory to read all parameters from descriptors - Updated test 
helper to populate all descriptors - All tests passing (13/13 including 3 new DL tests) --- .../builder/conv_algorithm_concepts.hpp | 83 ++++++++++++++++ .../include/ck_tile/builder/conv_factory.hpp | 95 ++++++++++++------- experimental/builder/test/CMakeLists.txt | 1 + .../conv/test_ckb_conv_fwd_2d_dl_fp16.cpp | 69 ++++++++++++++ .../test/impl/conv_algorithm_types.hpp | 80 ++++++++++++++++ .../test/utils/ckb_conv_test_common.hpp | 71 ++++++++++++++ 6 files changed, 365 insertions(+), 34 deletions(-) create mode 100644 experimental/builder/test/conv/test_ckb_conv_fwd_2d_dl_fp16.cpp diff --git a/experimental/builder/include/ck_tile/builder/conv_algorithm_concepts.hpp b/experimental/builder/include/ck_tile/builder/conv_algorithm_concepts.hpp index 586a119c75..a298c32b02 100644 --- a/experimental/builder/include/ck_tile/builder/conv_algorithm_concepts.hpp +++ b/experimental/builder/include/ck_tile/builder/conv_algorithm_concepts.hpp @@ -183,4 +183,87 @@ concept SpecifiesLoopScheduler = requires { { T::loop_scheduler } -> std::convertible_to; }; +/******************************************** */ +/* DL-specific descriptors and requirements */ +/******************************************** */ + +// Concept for DL thread configuration +template +concept DlThreadConfigDescriptor = requires(T t) { + { t.k0_per_block } -> std::convertible_to; + { t.k1 } -> std::convertible_to; + { t.m1_per_thread } -> std::convertible_to; + { t.n1_per_thread } -> std::convertible_to; + { t.k_per_thread } -> std::convertible_to; +}; + +// Concept for DL thread cluster +template +concept DlThreadClusterDescriptor = requires(T t) { + { t.m1_xs } -> std::convertible_to>; + { t.n1_xs } -> std::convertible_to>; +}; + +// Concept for DL block transfer K0_M0_M1_K1 format +template +concept DlBlockTransferK0M0M1K1Descriptor = requires(T t) { + { t.thread_slice_lengths } -> std::convertible_to>; + { t.thread_cluster_lengths } -> std::convertible_to>; + { t.thread_cluster_arrange_order } -> 
std::convertible_to>; + { t.src_access_order } -> std::convertible_to>; + { t.src_vector_tensor_lengths } -> std::convertible_to>; + { t.src_vector_tensor_contiguous_dim_order } -> std::convertible_to>; + { t.dst_vector_tensor_lengths } -> std::convertible_to>; +}; + +// Concept for DL block transfer K0_N0_N1_K1 format +template +concept DlBlockTransferK0N0N1K1Descriptor = requires(T t) { + { t.thread_slice_lengths } -> std::convertible_to>; + { t.thread_cluster_lengths } -> std::convertible_to>; + { t.thread_cluster_arrange_order } -> std::convertible_to>; + { t.src_access_order } -> std::convertible_to>; + { t.src_vector_tensor_lengths } -> std::convertible_to>; + { t.src_vector_tensor_contiguous_dim_order } -> std::convertible_to>; + { t.dst_vector_tensor_lengths } -> std::convertible_to>; +}; + +// Concept for DL C thread transfer +template +concept DlCThreadTransferDescriptor = requires(T t) { + { t.src_dst_access_order } -> std::convertible_to>; + { t.src_dst_vector_dim } -> std::convertible_to; + { t.dst_scalar_per_vector } -> std::convertible_to; +}; + +// Concept to check if algorithm specifies DL thread config +template +concept SpecifiesDlThreadConfig = requires { + { T::dl_thread_config } -> DlThreadConfigDescriptor; +}; + +// Concept to check if algorithm specifies DL thread cluster +template +concept SpecifiesDlThreadCluster = requires { + { T::dl_thread_cluster } -> DlThreadClusterDescriptor; +}; + +// Concept to check if algorithm specifies DL A block transfer +template +concept SpecifiesDlBlockTransferA = requires { + { T::dl_block_transfer_a } -> DlBlockTransferK0M0M1K1Descriptor; +}; + +// Concept to check if algorithm specifies DL B block transfer +template +concept SpecifiesDlBlockTransferB = requires { + { T::dl_block_transfer_b } -> DlBlockTransferK0N0N1K1Descriptor; +}; + +// Concept to check if algorithm specifies DL C thread transfer +template +concept SpecifiesDlCThreadTransfer = requires { + { T::dl_c_thread_transfer } -> 
DlCThreadTransferDescriptor; +}; + } // namespace ck_tile::builder diff --git a/experimental/builder/include/ck_tile/builder/conv_factory.hpp b/experimental/builder/include/ck_tile/builder/conv_factory.hpp index 167b569ee5..8e14753d66 100644 --- a/experimental/builder/include/ck_tile/builder/conv_factory.hpp +++ b/experimental/builder/include/ck_tile/builder/conv_factory.hpp @@ -1015,6 +1015,16 @@ struct ConvFactory "specialization."); static_assert(SpecifiesGemmSpecialization, "The convolution algorithm descriptor must specify gemm specialization."); + static_assert(SpecifiesDlThreadConfig, + "DL algorithm must specify thread config."); + static_assert(SpecifiesDlThreadCluster, + "DL algorithm must specify thread cluster."); + static_assert(SpecifiesDlBlockTransferA, + "DL algorithm must specify A block transfer."); + static_assert(SpecifiesDlBlockTransferB, + "DL algorithm must specify B block transfer."); + static_assert(SpecifiesDlCThreadTransfer, + "DL algorithm must specify C thread transfer."); static constexpr auto FWD_CONV_SPECIALIZATION = factory_internal::SetFwdConvSpecialization(); @@ -1023,40 +1033,57 @@ struct ConvFactory static constexpr auto BLOCK = factory_internal::SetThreadBlockInfo(); - // DL-specific hardcoded parameters based on the example from convnd_fwd_dl_fp16.cpp - // These can be made configurable in the future - static constexpr ck::index_t K0PerBlock = 16; - static constexpr ck::index_t K1 = 2; - static constexpr ck::index_t M1PerThread = 4; - static constexpr ck::index_t N1PerThread = 4; - static constexpr ck::index_t KPerThread = 1; - - // Thread cluster configuration - using M1N1ThreadClusterM1Xs = ck::Sequence<8, 2>; - using M1N1ThreadClusterN1Xs = ck::Sequence<8, 2>; - - // A Block Transfer - K0_M0_M1_K1 tensor format - using ABlockTransferThreadSliceLengths_K0_M0_M1_K1 = ck::Sequence<8, 1, 1, 2>; - using ABlockTransferThreadClusterLengths_K0_M0_M1_K1 = ck::Sequence<2, 1, 128, 1>; - using ABlockTransferThreadClusterArrangeOrder = 
ck::Sequence<1, 2, 0, 3>; - using ABlockTransferSrcAccessOrder = ck::Sequence<1, 2, 0, 3>; - using ABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1 = ck::Sequence<4, 1, 1, 2>; - using ABlockTransferSrcVectorTensorContiguousDimOrder = ck::Sequence<1, 2, 0, 3>; - using ABlockTransferDstVectorTensorLengths_K0_M0_M1_K1 = ck::Sequence<1, 1, 1, 2>; - - // B Block Transfer - K0_N0_N1_K1 tensor format - using BBlockTransferThreadSliceLengths_K0_N0_N1_K1 = ck::Sequence<8, 1, 1, 2>; - using BBlockTransferThreadClusterLengths_K0_N0_N1_K1 = ck::Sequence<2, 1, 128, 1>; - using BBlockTransferThreadClusterArrangeOrder = ck::Sequence<1, 2, 0, 3>; - using BBlockTransferSrcAccessOrder = ck::Sequence<1, 2, 0, 3>; - using BBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1 = ck::Sequence<4, 1, 1, 2>; - using BBlockTransferSrcVectorTensorContiguousDimOrder = ck::Sequence<1, 2, 0, 3>; - using BBlockTransferDstVectorTensorLengths_K0_N0_N1_K1 = ck::Sequence<1, 1, 1, 2>; - - // C Thread Transfer - using CThreadTransferSrcDstAccessOrder = ck::Sequence<0, 1, 2, 3, 4, 5>; - static constexpr ck::index_t CThreadTransferSrcDstVectorDim = 5; - static constexpr ck::index_t CThreadTransferDstScalarPerVector = 4; + // DL-specific parameters from algorithm descriptor + static constexpr auto DL_THREAD_CFG = ALGORITHM.dl_thread_config; + static constexpr ck::index_t K0PerBlock = DL_THREAD_CFG.k0_per_block; + static constexpr ck::index_t K1 = DL_THREAD_CFG.k1; + static constexpr ck::index_t M1PerThread = DL_THREAD_CFG.m1_per_thread; + static constexpr ck::index_t N1PerThread = DL_THREAD_CFG.n1_per_thread; + static constexpr ck::index_t KPerThread = DL_THREAD_CFG.k_per_thread; + + // Thread cluster from descriptor + static constexpr auto DL_CLUSTER = ALGORITHM.dl_thread_cluster; + using M1N1ThreadClusterM1Xs = to_sequence_v; + using M1N1ThreadClusterN1Xs = to_sequence_v; + + // A Block Transfer from descriptor - K0_M0_M1_K1 tensor format + static constexpr auto DL_A_TRANSFER = ALGORITHM.dl_block_transfer_a; 
+ using ABlockTransferThreadSliceLengths_K0_M0_M1_K1 = + to_sequence_v; + using ABlockTransferThreadClusterLengths_K0_M0_M1_K1 = + to_sequence_v; + using ABlockTransferThreadClusterArrangeOrder = + to_sequence_v; + using ABlockTransferSrcAccessOrder = to_sequence_v; + using ABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1 = + to_sequence_v; + using ABlockTransferSrcVectorTensorContiguousDimOrder = + to_sequence_v; + using ABlockTransferDstVectorTensorLengths_K0_M0_M1_K1 = + to_sequence_v; + + // B Block Transfer from descriptor - K0_N0_N1_K1 tensor format + static constexpr auto DL_B_TRANSFER = ALGORITHM.dl_block_transfer_b; + using BBlockTransferThreadSliceLengths_K0_N0_N1_K1 = + to_sequence_v; + using BBlockTransferThreadClusterLengths_K0_N0_N1_K1 = + to_sequence_v; + using BBlockTransferThreadClusterArrangeOrder = + to_sequence_v; + using BBlockTransferSrcAccessOrder = to_sequence_v; + using BBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1 = + to_sequence_v; + using BBlockTransferSrcVectorTensorContiguousDimOrder = + to_sequence_v; + using BBlockTransferDstVectorTensorLengths_K0_N0_N1_K1 = + to_sequence_v; + + // C Thread Transfer from descriptor + static constexpr auto DL_C_TRANSFER = ALGORITHM.dl_c_thread_transfer; + using CThreadTransferSrcDstAccessOrder = to_sequence_v; + static constexpr ck::index_t CThreadTransferSrcDstVectorDim = DL_C_TRANSFER.src_dst_vector_dim; + static constexpr ck::index_t CThreadTransferDstScalarPerVector = + DL_C_TRANSFER.dst_scalar_per_vector; // The DL forward convolution kernel class instance using Instance = ck::tensor_operation::device::DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK< diff --git a/experimental/builder/test/CMakeLists.txt b/experimental/builder/test/CMakeLists.txt index 26a666a805..d543255ad4 100644 --- a/experimental/builder/test/CMakeLists.txt +++ b/experimental/builder/test/CMakeLists.txt @@ -41,6 +41,7 @@ add_ck_builder_test(test_ckb_build_fwd_instances conv/test_ckb_conv_fwd_2d_bf16.cpp 
conv/test_ckb_conv_fwd_2d_fp16.cpp conv/test_ckb_conv_fwd_2d_fp32.cpp + conv/test_ckb_conv_fwd_2d_dl_fp16.cpp conv/test_ckb_conv_fwd_3d_bf16.cpp conv/test_ckb_conv_fwd_3d_fp16.cpp conv/test_ckb_conv_fwd_3d_fp32.cpp) diff --git a/experimental/builder/test/conv/test_ckb_conv_fwd_2d_dl_fp16.cpp b/experimental/builder/test/conv/test_ckb_conv_fwd_2d_dl_fp16.cpp new file mode 100644 index 0000000000..7138e47545 --- /dev/null +++ b/experimental/builder/test/conv/test_ckb_conv_fwd_2d_dl_fp16.cpp @@ -0,0 +1,69 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "utils/ckb_conv_test_common.hpp" + +using namespace ck_tile::builder::test_utils; + +namespace ck_tile::builder::testing { + +TEST(FwdConvInstances, Create_DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK_Instance_2D_FP16_GNHWC) +{ + constexpr ConvSignature FwdConvSignature{ + .spatial_dim = 2, + .direction = ConvDirection::FORWARD, + .layout = GroupConvLayout2D::GNHWC_GKYXC_GNHWK, + .data_type = DataType::FP16, + .elementwise_operation = ElementwiseOperation::PASS_THROUGH, + .device_operation = + FwdGroupConvDeviceOperation::DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK}; + + constexpr ThreadBlock FwdThreadBlock{.block_size = 256, + .tile_size = {.m = 128, .n = 128, .k = 16}}; + + run_test_DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK(); +} + +TEST(FwdConvInstances, Create_DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK_Instance_2D_FP16_NHWGC) +{ + constexpr ConvSignature FwdConvSignature{ + .spatial_dim = 2, + .direction = ConvDirection::FORWARD, + .layout = GroupConvLayout2D::NHWGC_GKYXC_NHWGK, + .data_type = DataType::FP16, + .elementwise_operation = ElementwiseOperation::PASS_THROUGH, + .device_operation = + FwdGroupConvDeviceOperation::DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK}; + + constexpr ThreadBlock FwdThreadBlock{.block_size = 256, + .tile_size = {.m = 128, .n = 128, .k = 16}}; + + 
run_test_DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK(); +} + +TEST(FwdConvInstances, + Create_DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK_Instance_2D_FP16_FILTER_1X1_PAD0) +{ + constexpr ConvSignature FwdConvSignature{ + .spatial_dim = 2, + .direction = ConvDirection::FORWARD, + .layout = GroupConvLayout2D::GNHWC_GKYXC_GNHWK, + .data_type = DataType::FP16, + .elementwise_operation = ElementwiseOperation::PASS_THROUGH, + .device_operation = + FwdGroupConvDeviceOperation::DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK}; + + constexpr ThreadBlock FwdThreadBlock{.block_size = 256, + .tile_size = {.m = 128, .n = 128, .k = 16}}; + + run_test_DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK< + FwdConvSignature, + FwdThreadBlock, + ConvFwdSpecialization::FILTER_1X1_PAD0>(); +} + +} // namespace ck_tile::builder::testing diff --git a/experimental/builder/test/impl/conv_algorithm_types.hpp b/experimental/builder/test/impl/conv_algorithm_types.hpp index 9c5ca9b97b..921c7953e8 100644 --- a/experimental/builder/test/impl/conv_algorithm_types.hpp +++ b/experimental/builder/test/impl/conv_algorithm_types.hpp @@ -214,4 +214,84 @@ static_assert( static_assert( ckb::SpecifiesLoopScheduler); +// DL-specific descriptors +struct DlThreadConfig +{ + size_t k0_per_block; + size_t k1; + size_t m1_per_thread; + size_t n1_per_thread; + size_t k_per_thread; +}; +static_assert(ckb::DlThreadConfigDescriptor); + +struct DlThreadCluster +{ + std::array m1_xs; // e.g., {8, 2} + std::array n1_xs; // e.g., {8, 2} +}; +static_assert(ckb::DlThreadClusterDescriptor); + +struct DlBlockTransferK0M0M1K1 +{ + std::array thread_slice_lengths; + std::array thread_cluster_lengths; + std::array thread_cluster_arrange_order; + std::array src_access_order; + std::array src_vector_tensor_lengths; + std::array src_vector_tensor_contiguous_dim_order; + std::array dst_vector_tensor_lengths; +}; +static_assert(ckb::DlBlockTransferK0M0M1K1Descriptor); + +struct DlBlockTransferK0N0N1K1 +{ + std::array 
thread_slice_lengths; + std::array thread_cluster_lengths; + std::array thread_cluster_arrange_order; + std::array src_access_order; + std::array src_vector_tensor_lengths; + std::array src_vector_tensor_contiguous_dim_order; + std::array dst_vector_tensor_lengths; +}; +static_assert(ckb::DlBlockTransferK0N0N1K1Descriptor); + +struct DlCThreadTransfer +{ + std::array src_dst_access_order; + size_t src_dst_vector_dim; + size_t dst_scalar_per_vector; +}; +static_assert(ckb::DlCThreadTransferDescriptor); + +struct ConvAlgorithm_DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK +{ + ThreadBlock thread_block; + ConvFwdSpecialization fwd_specialization; + GemmSpecialization gemm_specialization; + DlThreadConfig dl_thread_config; + DlThreadCluster dl_thread_cluster; + DlBlockTransferK0M0M1K1 dl_block_transfer_a; + DlBlockTransferK0N0N1K1 dl_block_transfer_b; + DlCThreadTransfer dl_c_thread_transfer; +}; +static_assert( + ckb::ConvAlgorithmDescriptor); +static_assert( + ckb::SpecifiesThreadBlock); +static_assert(ckb::SpecifiesFwdConcSpecialization< + ConvAlgorithm_DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK>); +static_assert( + ckb::SpecifiesGemmSpecialization); +static_assert( + ckb::SpecifiesDlThreadConfig); +static_assert( + ckb::SpecifiesDlThreadCluster); +static_assert( + ckb::SpecifiesDlBlockTransferA); +static_assert( + ckb::SpecifiesDlBlockTransferB); +static_assert( + ckb::SpecifiesDlCThreadTransfer); + } // namespace ck_tile::builder::test diff --git a/experimental/builder/test/utils/ckb_conv_test_common.hpp b/experimental/builder/test/utils/ckb_conv_test_common.hpp index d18a008015..f85c78bbdf 100644 --- a/experimental/builder/test/utils/ckb_conv_test_common.hpp +++ b/experimental/builder/test/utils/ckb_conv_test_common.hpp @@ -235,4 +235,75 @@ constexpr void run_test_DeviceGroupedConvFwdMultipleD_Wmma_CShuffle() EXPECT_NE(invoker_ptr, nullptr); } +template +constexpr void run_test_DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK() +{ + // DL thread configuration 
+ constexpr DlThreadConfig DlThreadCfg{ + .k0_per_block = 16, .k1 = 2, .m1_per_thread = 4, .n1_per_thread = 4, .k_per_thread = 1}; + + // DL thread cluster + constexpr DlThreadCluster DlCluster{.m1_xs = {8, 2}, .n1_xs = {8, 2}}; + + // DL A block transfer - K0_M0_M1_K1 format + constexpr DlBlockTransferK0M0M1K1 DlBlockTransferA{ + .thread_slice_lengths = {8, 1, 1, 2}, + .thread_cluster_lengths = {2, 1, 128, 1}, + .thread_cluster_arrange_order = {1, 2, 0, 3}, + .src_access_order = {1, 2, 0, 3}, + .src_vector_tensor_lengths = {4, 1, 1, 2}, + .src_vector_tensor_contiguous_dim_order = {1, 2, 0, 3}, + .dst_vector_tensor_lengths = {1, 1, 1, 2}}; + + // DL B block transfer - K0_N0_N1_K1 format + constexpr DlBlockTransferK0N0N1K1 DlBlockTransferB{ + .thread_slice_lengths = {8, 1, 1, 2}, + .thread_cluster_lengths = {2, 1, 128, 1}, + .thread_cluster_arrange_order = {1, 2, 0, 3}, + .src_access_order = {1, 2, 0, 3}, + .src_vector_tensor_lengths = {4, 1, 1, 2}, + .src_vector_tensor_contiguous_dim_order = {1, 2, 0, 3}, + .dst_vector_tensor_lengths = {1, 1, 1, 2}}; + + // DL C thread transfer + constexpr DlCThreadTransfer DlCTransfer{.src_dst_access_order = {0, 1, 2, 3, 4, 5}, + .src_dst_vector_dim = 5, + .dst_scalar_per_vector = 4}; + + constexpr ConvAlgorithm_DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK FwdConvAlgorithm{ + .thread_block = FwdThreadBlock, + .fwd_specialization = FwdConvSpecialization, + .gemm_specialization = GemmSpecialization::MNKPadding, + .dl_thread_config = DlThreadCfg, + .dl_thread_cluster = DlCluster, + .dl_block_transfer_a = DlBlockTransferA, + .dl_block_transfer_b = DlBlockTransferB, + .dl_c_thread_transfer = DlCTransfer}; + + using Builder = ConvBuilder; + + auto instance = typename Builder::Instance{}; + + const auto kernel_string = instance.GetTypeString(); + std::cout << "Generated kernel: " << kernel_string << std::endl; + EXPECT_GT(kernel_string.size(), 0); + + 
EXPECT_TRUE(kernel_string.starts_with("DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK")); + + // Verify specialization is correct + if(FwdConvSpecialization == ConvFwdSpecialization::DEFAULT) + EXPECT_TRUE(kernel_string.find("Default") != std::string::npos); + else if(FwdConvSpecialization == ConvFwdSpecialization::FILTER_1X1_PAD0) + EXPECT_TRUE(kernel_string.find("Filter1x1Pad0") != std::string::npos); + else if(FwdConvSpecialization == ConvFwdSpecialization::FILTER_1X1_STRIDE1_PAD0) + EXPECT_TRUE(kernel_string.find("Filter1x1Stride1Pad0") != std::string::npos); + else if(FwdConvSpecialization == ConvFwdSpecialization::FILTER_3x3) + EXPECT_TRUE(kernel_string.find("Filter3x3") != std::string::npos); + + const auto invoker_ptr = instance.MakeInvokerPointer(); + EXPECT_NE(invoker_ptr, nullptr); +} + } // namespace ck_tile::builder::test_utils From 6c778f6d3debac12c5b8b71cd96cf950cccf6ad2 Mon Sep 17 00:00:00 2001 From: JH-Leon-KIM-AMD Date: Wed, 5 Nov 2025 08:23:30 +0000 Subject: [PATCH 09/12] Add factory and test support for DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor - Add factory specialization for Large_Tensor device operation (conv_factory.hpp lines 1145-1265) - Add macro collision workaround using pragma push/pop (conv_factory.hpp lines 43-51) - Add test helper function run_test_DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor - Add builder test file test_ckb_conv_fwd_2d_large_tensor_fp16.cpp with 2 test cases - Update CMakeLists.txt to include new test file - Reuse existing ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle descriptor - Map all 42 template parameters identical to regular XDL CShuffle - All 15 builder tests passing including 2 new Large_Tensor tests Completes Task 350: All 4 forward convolution device operations now supported in CK Builder. 
--- .../include/ck_tile/builder/conv_factory.hpp | 130 ++++++++++++++++++ experimental/builder/test/CMakeLists.txt | 1 + ...test_ckb_conv_fwd_2d_large_tensor_fp16.cpp | 53 +++++++ .../test/utils/ckb_conv_test_common.hpp | 74 ++++++++++ 4 files changed, 258 insertions(+) create mode 100644 experimental/builder/test/conv/test_ckb_conv_fwd_2d_large_tensor_fp16.cpp diff --git a/experimental/builder/include/ck_tile/builder/conv_factory.hpp b/experimental/builder/include/ck_tile/builder/conv_factory.hpp index 8e14753d66..4e80d83a36 100644 --- a/experimental/builder/include/ck_tile/builder/conv_factory.hpp +++ b/experimental/builder/include/ck_tile/builder/conv_factory.hpp @@ -40,6 +40,17 @@ #include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_wmma_cshuffle.hpp" #include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp" #include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp" +// WORKAROUND: Macro namespace collision in upstream CK device operation headers. +// device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp (line 41) and +// device_grouped_conv_fwd_multiple_d_xdl_large_tensor_cshuffle.hpp (line 51) both define +// GridwiseGemmTemplateParameters macro without #undef, causing redefinition errors. +// Use pragma push/pop to isolate the Large_Tensor header's macro scope. 
+#pragma push_macro("GridwiseGemmTemplateParameters") +#ifdef GridwiseGemmTemplateParameters +#undef GridwiseGemmTemplateParameters +#endif +#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_xdl_large_tensor_cshuffle.hpp" +#pragma pop_macro("GridwiseGemmTemplateParameters") #include "ck_tile/builder/conv_signature_concepts.hpp" #include "ck_tile/builder/conv_algorithm_concepts.hpp" #include "ck_tile/builder/conv_algorithm_limits.hpp" @@ -1131,4 +1142,123 @@ struct ConvFactory CThreadTransferDstScalarPerVector>; }; +// Factory specialization for DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor instance +// of a grouped forward convolution kernel with large tensor support (N-splitting). +template + requires ConvDirectionIsForward && + ConvDeviceOpIs_DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor +struct ConvFactory +{ + static constexpr size_t SPATIAL_DIM = SIGNATURE.spatial_dim; + using Layouts = decltype(factory_internal::GetTensorLayout()); + using Types = factory_internal::ConvTensorTypes; + using Ops = factory_internal::ElementwiseOps; + using AlgorithmType = decltype(ALGORITHM); + + static_assert(SpecifiesThreadBlock, + "The convolution algorithm descriptor must specify thread block info."); + static_assert(SpecifiesGridwiseXdlGemm, + "The convolution algorithm descriptor must specify gridwise GEMM info."); + static_assert(SpecifiesBlockTransfer, + "The convolution algorithm descriptor must specify block transfer info."); + static_assert(SpecifiesLdsTransfer, + "The convolution algorithm descriptor must specify LDS transfer info."); + static_assert( + SpecifiesThreadClusterAccessOrder, + "The convolution algorithm descriptor must specify thread cluster access order info."); + static_assert(SpecifiesSourceAccessOrder, + "The convolution algorithm descriptor must specify source access order info."); + static_assert(SpecifiesFwdConcSpecialization, + "The convolution algorithm descriptor must specify forward 
convolution " + "specialization."); + static_assert(SpecifiesGemmSpecialization, + "The convolution algorithm descriptor must specify gemm specialization."); + static_assert(SpecifiesNumPrefetchStages, + "The convolution algorithm descriptor must specify number of prefetch stages."); + static_assert(SpecifiesLoopScheduler, + "The convolution algorithm descriptor must specify loop scheduler."); + + static constexpr auto FWD_CONV_SPECIALIZATION = + factory_internal::SetFwdConvSpecialization(); + static constexpr auto GEMM_SPECIALIZATION = + factory_internal::SetGemmSpecialization(); + static constexpr factory_internal::ConvSpec SPECIALIZATION{.conv_spec = FWD_CONV_SPECIALIZATION, + .gemm_spec = GEMM_SPECIALIZATION}; + + static constexpr auto LOOP_SCHEDULER = factory_internal::SetLoopScheduler(); + static constexpr auto BLOCK = factory_internal::SetThreadBlockInfo(); + static constexpr auto GRIDWISE_GEMM = ALGORITHM.gridwise_gemm; + static constexpr auto A_BLOCK_TRANSFER = + factory_internal::SetFwdConvABlockTransfer(); + static constexpr auto B_BLOCK_TRANSFER = + factory_internal::SetFwdConvBBlockTransfer(); + static constexpr auto C_BLOCK_TRANSFER = + factory_internal::SetCBlockTransfer(); + + // Check limits for the algorithm parameters. + static_assert(InputVectorTransferLimits); + static_assert(InputVectorTransferLimits); + static_assert(OutputVectorTransferLimits); + static_assert(AccessOrderLimits); + static_assert(AccessOrderLimits); + static_assert(AccessOrderLimits); + static_assert(AccessOrderLimits); + + // The forward convolution kernel class instance with large tensor support. 
+ using Instance = + ck::tensor_operation::device::DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor< + SPATIAL_DIM, + typename Layouts::ALayout, + typename Layouts::BLayout, + typename Layouts::DsLayout, + typename Layouts::ELayout, + typename Types::ADataType, + typename Types::BDataType, + typename Types::AccDataType, + typename Types::CShuffleDataType, + typename Types::DsDataTypes, + typename Types::EDataType, + typename Ops::AElementwiseOp, + typename Ops::BElementwiseOp, + typename Ops::CDEElementwiseOp, + SPECIALIZATION.conv_spec, + SPECIALIZATION.gemm_spec, + ALGORITHM.num_gemm_k_prefetch_stages, + BLOCK.block_size, + BLOCK.per_block.m, + BLOCK.per_block.n, + BLOCK.per_block.k, + GRIDWISE_GEMM.ak1, + GRIDWISE_GEMM.bk1, + GRIDWISE_GEMM.m_per_xdl, + GRIDWISE_GEMM.n_per_xdl, + GRIDWISE_GEMM.m_xdl_per_wave, + GRIDWISE_GEMM.n_xdl_per_wave, + to_sequence_v, + to_sequence_v, + to_sequence_v, + A_BLOCK_TRANSFER.src_vector_dim, + A_BLOCK_TRANSFER.src_scalar_per_vector, + A_BLOCK_TRANSFER.lds_dst_scalar_per_vector, + A_BLOCK_TRANSFER.lds_padding, + to_sequence_v, + to_sequence_v, + to_sequence_v, + B_BLOCK_TRANSFER.src_vector_dim, + B_BLOCK_TRANSFER.src_scalar_per_vector, + B_BLOCK_TRANSFER.lds_dst_scalar_per_vector, + B_BLOCK_TRANSFER.lds_padding, + C_BLOCK_TRANSFER.m_per_wave_per_shuffle, + C_BLOCK_TRANSFER.n_per_wave_per_shuffle, + to_sequence_v, + C_BLOCK_TRANSFER.scalar_per_vector, + typename Types::AComputeType, + typename Types::BComputeType, + LOOP_SCHEDULER>; +}; + } // namespace ck_tile::builder diff --git a/experimental/builder/test/CMakeLists.txt b/experimental/builder/test/CMakeLists.txt index d543255ad4..9bc10117dd 100644 --- a/experimental/builder/test/CMakeLists.txt +++ b/experimental/builder/test/CMakeLists.txt @@ -42,6 +42,7 @@ add_ck_builder_test(test_ckb_build_fwd_instances conv/test_ckb_conv_fwd_2d_fp16.cpp conv/test_ckb_conv_fwd_2d_fp32.cpp conv/test_ckb_conv_fwd_2d_dl_fp16.cpp + conv/test_ckb_conv_fwd_2d_large_tensor_fp16.cpp 
conv/test_ckb_conv_fwd_3d_bf16.cpp conv/test_ckb_conv_fwd_3d_fp16.cpp conv/test_ckb_conv_fwd_3d_fp32.cpp) diff --git a/experimental/builder/test/conv/test_ckb_conv_fwd_2d_large_tensor_fp16.cpp b/experimental/builder/test/conv/test_ckb_conv_fwd_2d_large_tensor_fp16.cpp new file mode 100644 index 0000000000..333ead2aaf --- /dev/null +++ b/experimental/builder/test/conv/test_ckb_conv_fwd_2d_large_tensor_fp16.cpp @@ -0,0 +1,53 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "utils/ckb_conv_test_common.hpp" + +using namespace ck_tile::builder::test_utils; + +namespace ck_tile::builder::testing { + +TEST(FwdConvInstances, + Create_DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor_Instance_2D_FP16_GNHWC) +{ + constexpr ConvSignature FwdConvSignature{ + .spatial_dim = 2, + .direction = ConvDirection::FORWARD, + .layout = GroupConvLayout2D::GNHWC_GKYXC_GNHWK, + .data_type = DataType::FP16, + .elementwise_operation = ElementwiseOperation::PASS_THROUGH, + .device_operation = + FwdGroupConvDeviceOperation::DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor}; + + constexpr ThreadBlock FwdThreadBlock{.block_size = 256, + .tile_size = {.m = 256, .n = 128, .k = 32}}; + + run_test_DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor< + FwdConvSignature, + FwdThreadBlock, + ConvFwdSpecialization::DEFAULT>(); +} + +TEST( + FwdConvInstances, + Create_DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor_Instance_2D_FP16_GNHWC_Filter1x1Pad0) +{ + constexpr ConvSignature FwdConvSignature{ + .spatial_dim = 2, + .direction = ConvDirection::FORWARD, + .layout = GroupConvLayout2D::GNHWC_GKYXC_GNHWK, + .data_type = DataType::FP16, + .elementwise_operation = ElementwiseOperation::PASS_THROUGH, + .device_operation = + FwdGroupConvDeviceOperation::DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor}; + + constexpr ThreadBlock FwdThreadBlock{.block_size = 128, + .tile_size = {.m = 128, .n = 128, .k = 
32}}; + + run_test_DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor< + FwdConvSignature, + FwdThreadBlock, + ConvFwdSpecialization::FILTER_1X1_PAD0>(); +} + +} // namespace ck_tile::builder::testing diff --git a/experimental/builder/test/utils/ckb_conv_test_common.hpp b/experimental/builder/test/utils/ckb_conv_test_common.hpp index f85c78bbdf..0e66428fdb 100644 --- a/experimental/builder/test/utils/ckb_conv_test_common.hpp +++ b/experimental/builder/test/utils/ckb_conv_test_common.hpp @@ -306,4 +306,78 @@ constexpr void run_test_DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK() EXPECT_NE(invoker_ptr, nullptr); } +// Test helper for DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor +// Note: Large_Tensor has identical parameters to regular XDL CShuffle +template +constexpr void run_test_DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor() +{ + constexpr GridwiseXdlGemm FwdGemmParams{.ak1 = 8, + .bk1 = 8, + .m_per_xdl = 32, + .n_per_xdl = 32, + .m_xdl_per_wave = 2, + .n_xdl_per_wave = 1}; + + constexpr BlockTransferABC FwdBlockTransfer{.block_transfer_a = {.k0 = 4, .m_n = 16, .k1 = 1}, + .block_transfer_b = {.k0 = 4, .m_n = 16, .k1 = 1}, + .thread_cluster_dims_c = {.m_block = 1, + .m_wave_per_xdl = 16, + .n_block = 1, + .n_wave_per_xdl = 4}, + .lds_transfer_a = {.src_vector_dim = 2, + .src_scalar_per_vector = 8, + .lds_dst_scalar_per_vector = 8, + .is_direct_load = false, + .lds_padding = true}, + .lds_transfer_b = {.src_vector_dim = 2, + .src_scalar_per_vector = 8, + .lds_dst_scalar_per_vector = 8, + .is_direct_load = false, + .lds_padding = true}, + .epilogue_c = {.m_per_wave_per_shuffle = 1, + .n_per_wave_per_shuffle = 1, + .scalar_per_vector = 8}, + .block_transfer_access_order_a = {1, 0, 2}, + .block_transfer_access_order_b = {1, 0, 2}, + .src_access_order_a = {1, 0, 2}, + .src_access_order_b = {1, 0, 2}}; + + // Large_Tensor uses the same descriptor as regular XDL CShuffle + constexpr ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle 
FwdConvAlgorithm{ + .thread_block = FwdThreadBlock, + .gridwise_gemm = FwdGemmParams, + .block_transfer = FwdBlockTransfer, + .fwd_specialization = FwdConvSpecialization, + .gemm_specialization = GemmSpecialization::MNKPadding, + .num_gemm_k_prefetch_stages = 1, + .num_groups_to_merge = 1, + .loop_scheduler = LoopScheduler::DEFAULT}; + + using Builder = ConvBuilder; + + auto instance = typename Builder::Instance{}; + + const auto kernel_string = instance.GetTypeString(); + std::cout << "Generated kernel: " << kernel_string << std::endl; + EXPECT_GT(kernel_string.size(), 0); + + EXPECT_TRUE( + kernel_string.starts_with("DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor")); + + // Verify specialization is correct + if(FwdConvSpecialization == ConvFwdSpecialization::DEFAULT) + EXPECT_TRUE(kernel_string.find("Default") != std::string::npos); + else if(FwdConvSpecialization == ConvFwdSpecialization::FILTER_1X1_PAD0) + EXPECT_TRUE(kernel_string.find("Filter1x1Pad0") != std::string::npos); + else if(FwdConvSpecialization == ConvFwdSpecialization::FILTER_1X1_STRIDE1_PAD0) + EXPECT_TRUE(kernel_string.find("Filter1x1Stride1Pad0") != std::string::npos); + else if(FwdConvSpecialization == ConvFwdSpecialization::FILTER_3x3) + EXPECT_TRUE(kernel_string.find("Filter3x3") != std::string::npos); + + const auto invoker_ptr = instance.MakeInvokerPointer(); + EXPECT_NE(invoker_ptr, nullptr); +} + } // namespace ck_tile::builder::test_utils From 43f104a76d4387bad840383fc59aa15420170ad9 Mon Sep 17 00:00:00 2001 From: JH-Leon-KIM-AMD Date: Thu, 6 Nov 2025 08:28:45 +0000 Subject: [PATCH 10/12] Update copyright headers to new format - Change copyright format to: Copyright (C) Advanced Micro Devices, Inc., or its affiliates. 
- Reorder headers: Copyright first, then SPDX-License-Identifier - Updated files: * experimental/builder/test/conv/test_ckb_conv_fwd_2d_dl_fp16.cpp * experimental/builder/test/conv/test_ckb_conv_fwd_2d_large_tensor_fp16.cpp * experimental/builder/include/ck_tile/builder/device_op_types.hpp --- .../ck_tile/builder/device_op_types.hpp | 21 +++++++++++++++++++ .../conv/test_ckb_conv_fwd_2d_dl_fp16.cpp | 2 +- ...test_ckb_conv_fwd_2d_large_tensor_fp16.cpp | 2 +- 3 files changed, 23 insertions(+), 2 deletions(-) create mode 100644 experimental/builder/include/ck_tile/builder/device_op_types.hpp diff --git a/experimental/builder/include/ck_tile/builder/device_op_types.hpp b/experimental/builder/include/ck_tile/builder/device_op_types.hpp new file mode 100644 index 0000000000..b925c3d466 --- /dev/null +++ b/experimental/builder/include/ck_tile/builder/device_op_types.hpp @@ -0,0 +1,21 @@ +// Copyright (C) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#pragma once + +namespace ck_tile::builder { + +// Enumeration for CK Device Operation types. +// This allows the builder to select which device operation template to instantiate +// based on the user's requirements. 
+enum class DeviceOpType +{ + // Forward Convolution - Non-grouped + CONV_FWD, // Maps to: DeviceConvFwd (TODO: No implementation with tuning params exists yet) + + // Forward Convolution - Grouped + GROUPED_CONV_FWD_MULTIPLE_ABD, // Maps to: DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle + GROUPED_CONV_FWD_MULTIPLE_ABD_XDL_CSHUFFLE_V3, // Maps to: DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 +}; + +} // namespace ck_tile::builder diff --git a/experimental/builder/test/conv/test_ckb_conv_fwd_2d_dl_fp16.cpp b/experimental/builder/test/conv/test_ckb_conv_fwd_2d_dl_fp16.cpp index 7138e47545..12730bab19 100644 --- a/experimental/builder/test/conv/test_ckb_conv_fwd_2d_dl_fp16.cpp +++ b/experimental/builder/test/conv/test_ckb_conv_fwd_2d_dl_fp16.cpp @@ -1,5 +1,5 @@ +// Copyright (C) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. #include "utils/ckb_conv_test_common.hpp" diff --git a/experimental/builder/test/conv/test_ckb_conv_fwd_2d_large_tensor_fp16.cpp b/experimental/builder/test/conv/test_ckb_conv_fwd_2d_large_tensor_fp16.cpp index 333ead2aaf..0216c5907d 100644 --- a/experimental/builder/test/conv/test_ckb_conv_fwd_2d_large_tensor_fp16.cpp +++ b/experimental/builder/test/conv/test_ckb_conv_fwd_2d_large_tensor_fp16.cpp @@ -1,5 +1,5 @@ +// Copyright (C) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. 
#include "utils/ckb_conv_test_common.hpp" From 9e334bab60a1512ccb324302a9ce71a2ae8a31c2 Mon Sep 17 00:00:00 2001 From: JH-Leon-KIM-AMD Date: Thu, 6 Nov 2025 15:11:57 +0000 Subject: [PATCH 11/12] Fix clang-format-18 formatting in device_op_types.hpp --- .../builder/include/ck_tile/builder/device_op_types.hpp | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/experimental/builder/include/ck_tile/builder/device_op_types.hpp b/experimental/builder/include/ck_tile/builder/device_op_types.hpp index b925c3d466..1c8defee0c 100644 --- a/experimental/builder/include/ck_tile/builder/device_op_types.hpp +++ b/experimental/builder/include/ck_tile/builder/device_op_types.hpp @@ -15,7 +15,8 @@ enum class DeviceOpType // Forward Convolution - Grouped GROUPED_CONV_FWD_MULTIPLE_ABD, // Maps to: DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle - GROUPED_CONV_FWD_MULTIPLE_ABD_XDL_CSHUFFLE_V3, // Maps to: DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 + GROUPED_CONV_FWD_MULTIPLE_ABD_XDL_CSHUFFLE_V3, // Maps to: + // DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 }; } // namespace ck_tile::builder From 6a230ae8112cc64f2073a3c2836a9fd7a4cd1183 Mon Sep 17 00:00:00 2001 From: JH-Leon-KIM-AMD Date: Thu, 6 Nov 2025 15:18:26 +0000 Subject: [PATCH 12/12] Fix clang-format-18 error in device_op_types.hpp --- .../builder/include/ck_tile/builder/device_op_types.hpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/experimental/builder/include/ck_tile/builder/device_op_types.hpp b/experimental/builder/include/ck_tile/builder/device_op_types.hpp index 1c8defee0c..0e779fdf4e 100644 --- a/experimental/builder/include/ck_tile/builder/device_op_types.hpp +++ b/experimental/builder/include/ck_tile/builder/device_op_types.hpp @@ -16,7 +16,7 @@ enum class DeviceOpType // Forward Convolution - Grouped GROUPED_CONV_FWD_MULTIPLE_ABD, // Maps to: DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle GROUPED_CONV_FWD_MULTIPLE_ABD_XDL_CSHUFFLE_V3, // Maps to: - // DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 + //
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 }; } // namespace ck_tile::builder