Skip to content

Commit 8681ced

Browse files
authored
[CK TILE] Refactor Conv configs and Conv Elementwise (#3151)
* [CK TILE] Refactor Conv configs and Conv Elementwise * fix
1 parent 99f38e4 commit 8681ced

14 files changed

+236
-225
lines changed

example/ck_tile/20_grouped_convolution/conv_configs.hpp

Lines changed: 1 addition & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -18,11 +18,7 @@ struct ConvConfigBase
1818
static constexpr bool kPadN = true;
1919
static constexpr bool kPadK = true;
2020

21-
static constexpr bool PermuteA = false;
22-
static constexpr bool PermuteB = false;
23-
24-
static constexpr bool TransposeC = false;
25-
static constexpr bool UseStructuredSparsity = false;
21+
static constexpr bool TransposeC = false;
2622

2723
static constexpr ck_tile::index_t VectorSizeA = 4;
2824
static constexpr ck_tile::index_t VectorSizeB = 8;
@@ -34,8 +30,6 @@ struct ConvConfigBase
3430
static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Intrawave;
3531
static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::COMPUTE_V3;
3632
static constexpr ck_tile::index_t NumWaveGroups = 1;
37-
static constexpr bool Preshuffle = false;
38-
static constexpr bool TiledMMAPermuteN = false;
3933

4034
static constexpr ck_tile::index_t NumGroupsToMerge = 1;
4135
};

example/ck_tile/20_grouped_convolution/grouped_convolution_backward_data.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
#include "grouped_convolution_backward_data_invoker.hpp"
1515
#include "run_grouped_convolution_bwd_data_example.inc"
1616

17-
template <template <typename PrecType> typename GemmConfig>
17+
template <template <typename PrecType> typename ConvConfig>
1818
int run_grouped_conv_bwd_data_example(int argc, char* argv[])
1919
{
2020
using Invoker = GroupedConvolutionBackwardDataInvoker;
@@ -31,14 +31,14 @@ int run_grouped_conv_bwd_data_example(int argc, char* argv[])
3131
if(data_type == "fp16")
3232
{
3333
return run_grouped_conv_bwd_data_example_prec_type<Invoker,
34-
GemmConfig<ck_tile::half_t>,
34+
ConvConfig<ck_tile::half_t>,
3535
ck_tile::half_t>(
3636
in_layout, wei_layout, out_layout, argc, argv);
3737
}
3838
else if(data_type == "bf16")
3939
{
4040
return run_grouped_conv_bwd_data_example_prec_type<Invoker,
41-
GemmConfig<ck_tile::bf16_t>,
41+
ConvConfig<ck_tile::bf16_t>,
4242
ck_tile::bf16_t>(
4343
in_layout, wei_layout, out_layout, argc, argv);
4444
}

example/ck_tile/20_grouped_convolution/grouped_convolution_backward_data_invoker.hpp

Lines changed: 26 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ struct GroupedConvolutionBackwardDataInvoker
88
{
99

1010
template <ck_tile::index_t NDimSpatial,
11-
typename GemmConfig,
11+
typename ConvConfig,
1212
typename InDataType,
1313
typename WeiDataType,
1414
typename AccDataType,
@@ -26,12 +26,11 @@ struct GroupedConvolutionBackwardDataInvoker
2626

2727
// Implicit GEMM Traits
2828
using GemmShape = ck_tile::TileGemmShape<
29-
ck_tile::sequence<GemmConfig::M_Tile, GemmConfig::N_Tile, GemmConfig::K_Tile>,
30-
ck_tile::sequence<GemmConfig::M_Warp, GemmConfig::N_Warp, GemmConfig::K_Warp>,
31-
ck_tile::
32-
sequence<GemmConfig::M_Warp_Tile, GemmConfig::N_Warp_Tile, GemmConfig::K_Warp_Tile>,
33-
GemmConfig::PermuteA,
34-
GemmConfig::PermuteB>;
29+
ck_tile::sequence<ConvConfig::M_Tile, ConvConfig::N_Tile, ConvConfig::K_Tile>,
30+
ck_tile::sequence<ConvConfig::M_Warp, ConvConfig::N_Warp, ConvConfig::K_Warp>,
31+
ck_tile::sequence<ConvConfig::M_Warp_Tile,
32+
ConvConfig::N_Warp_Tile,
33+
ConvConfig::K_Warp_Tile>>;
3534

3635
constexpr ck_tile::index_t VectorSizeA = 8;
3736
constexpr ck_tile::index_t VectorSizeB = 8;
@@ -40,8 +39,8 @@ struct GroupedConvolutionBackwardDataInvoker
4039
constexpr auto ConvSpec = ck_tile::ConvolutionSpecialization::Default;
4140
using TilePartitioner =
4241
ck_tile::GemmSpatiallyLocalTilePartitioner<GemmShape,
43-
GemmConfig::TileParitionerGroupNum,
44-
GemmConfig::TileParitionerM01>;
42+
ConvConfig::TileParitionerGroupNum,
43+
ConvConfig::TileParitionerM01>;
4544
using GroupedConvTraitsType = ck_tile::GroupedConvTraits<NDimSpatial,
4645
ConvSpec,
4746
InLayout,
@@ -53,17 +52,17 @@ struct GroupedConvolutionBackwardDataInvoker
5352
VectorSizeC>;
5453

5554
using GemmUniversalTraits = ck_tile::TileGemmUniversalTraits<
56-
GemmConfig::kPadM,
57-
GemmConfig::kPadN,
58-
GemmConfig::kPadK,
59-
GemmConfig::DoubleSmemBuffer,
55+
ConvConfig::kPadM,
56+
ConvConfig::kPadN,
57+
ConvConfig::kPadK,
58+
ConvConfig::DoubleSmemBuffer,
6059
typename GroupedConvTraitsType::GroupedConvImplicitGemmTraitsBwdData::AsLayout,
6160
typename GroupedConvTraitsType::GroupedConvImplicitGemmTraitsBwdData::BsLayout,
6261
typename GroupedConvTraitsType::GroupedConvImplicitGemmTraitsBwdData::CLayout,
63-
GemmConfig::TransposeC,
64-
GemmConfig::UseStructuredSparsity,
62+
ConvConfig::TransposeC,
63+
false,
6564
false, // Persistent,
66-
GemmConfig::NumWaveGroups>;
65+
ConvConfig::NumWaveGroups>;
6766

6867
using GemmPipelineProblem = ck_tile::GemmPipelineProblem<
6968
OutDataType,
@@ -79,16 +78,16 @@ struct GroupedConvolutionBackwardDataInvoker
7978
VectorSizeB>;
8079

8180
using BaseGemmPipeline = typename PipelineTypeTraits<
82-
GemmConfig::Pipeline>::template UniversalGemmPipeline<GemmPipelineProblem>;
81+
ConvConfig::Pipeline>::template UniversalGemmPipeline<GemmPipelineProblem>;
8382

8483
const ck_tile::index_t gemm_k =
8584
args.K_ * std::accumulate(args.filter_spatial_lengths_.begin(),
8685
args.filter_spatial_lengths_.end(),
8786
1,
8887
std::multiplies<ck_tile::index_t>());
8988

90-
const ck_tile::index_t k_grain = args.k_batch * GemmConfig::K_Tile;
91-
const ck_tile::index_t K_split = (gemm_k + k_grain - 1) / k_grain * GemmConfig::K_Tile;
89+
const ck_tile::index_t k_grain = args.k_batch * ConvConfig::K_Tile;
90+
const ck_tile::index_t K_split = (gemm_k + k_grain - 1) / k_grain * ConvConfig::K_Tile;
9291
const ck_tile::index_t num_loop = TilePartitioner::GetLoopNum(K_split);
9392
const bool has_hot_loop = BaseGemmPipeline::BlockHasHotloop(num_loop);
9493
const ck_tile::TailNumber tail_num = BaseGemmPipeline::GetBlockLoopTailNum(num_loop);
@@ -98,7 +97,7 @@ struct GroupedConvolutionBackwardDataInvoker
9897
[&](const auto has_hot_loop_, const auto tail_number_, const auto memory_operation_) {
9998
constexpr bool has_hot_loop_v = has_hot_loop_.value;
10099
constexpr auto tail_number_v = tail_number_.value;
101-
constexpr auto scheduler = GemmConfig::Scheduler;
100+
constexpr auto scheduler = ConvConfig::Scheduler;
102101
constexpr auto memory_operation = memory_operation_.value;
103102

104103
using UniversalGemmProblem =
@@ -118,7 +117,7 @@ struct GroupedConvolutionBackwardDataInvoker
118117
VectorSizeB>;
119118

120119
using GemmPipeline = typename PipelineTypeTraits<
121-
GemmConfig::Pipeline>::template GemmPipeline<UniversalGemmProblem>;
120+
ConvConfig::Pipeline>::template GemmPipeline<UniversalGemmProblem>;
122121

123122
using ConvEpilogue = ck_tile::CShuffleEpilogue<ck_tile::CShuffleEpilogueProblem<
124123
OutDataType,
@@ -131,12 +130,12 @@ struct GroupedConvolutionBackwardDataInvoker
131130
CDEElementWise,
132131
TilePartitioner::MPerBlock,
133132
TilePartitioner::NPerBlock,
134-
GemmConfig::M_Warp,
135-
GemmConfig::N_Warp,
136-
GemmConfig::M_Warp_Tile,
137-
GemmConfig::N_Warp_Tile,
138-
GemmConfig::K_Warp_Tile,
139-
GemmConfig::TransposeC,
133+
ConvConfig::M_Warp,
134+
ConvConfig::N_Warp,
135+
ConvConfig::M_Warp_Tile,
136+
ConvConfig::N_Warp_Tile,
137+
ConvConfig::K_Warp_Tile,
138+
ConvConfig::TransposeC,
140139
memory_operation,
141140
1,
142141
true,

example/ck_tile/20_grouped_convolution/grouped_convolution_backward_weight_invoker.hpp

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -27,10 +27,9 @@ struct GroupedConvolutionBackwardWeightInvoker
2727
using GemmShape = ck_tile::TileGemmShape<
2828
ck_tile::sequence<ConvConfig::M_Tile, ConvConfig::N_Tile, ConvConfig::K_Tile>,
2929
ck_tile::sequence<ConvConfig::M_Warp, ConvConfig::N_Warp, ConvConfig::K_Warp>,
30-
ck_tile::
31-
sequence<ConvConfig::M_Warp_Tile, ConvConfig::N_Warp_Tile, ConvConfig::K_Warp_Tile>,
32-
ConvConfig::PermuteA,
33-
ConvConfig::PermuteB>;
30+
ck_tile::sequence<ConvConfig::M_Warp_Tile,
31+
ConvConfig::N_Warp_Tile,
32+
ConvConfig::K_Warp_Tile>>;
3433

3534
constexpr ck_tile::index_t VectorSizeA = ConvConfig::VectorSizeA;
3635
constexpr ck_tile::index_t VectorSizeB = ConvConfig::VectorSizeB;
@@ -61,7 +60,7 @@ struct GroupedConvolutionBackwardWeightInvoker
6160
typename GroupedConvTraitsType::GroupedConvImplicitGemmTraitsBwdWeight::BsLayout,
6261
typename GroupedConvTraitsType::GroupedConvImplicitGemmTraitsBwdWeight::CLayout,
6362
ConvConfig::TransposeC,
64-
ConvConfig::UseStructuredSparsity,
63+
false,
6564
false, // Persistent,
6665
ConvConfig::NumWaveGroups>;
6766

example/ck_tile/20_grouped_convolution/grouped_convolution_backward_weight_two_stage_invoker.hpp

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -29,10 +29,9 @@ struct GroupedConvolutionBackwardWeightTwoStageInvoker
2929
using GemmShape = ck_tile::TileGemmShape<
3030
ck_tile::sequence<ConvConfig::M_Tile, ConvConfig::N_Tile, ConvConfig::K_Tile>,
3131
ck_tile::sequence<ConvConfig::M_Warp, ConvConfig::N_Warp, ConvConfig::K_Warp>,
32-
ck_tile::
33-
sequence<ConvConfig::M_Warp_Tile, ConvConfig::N_Warp_Tile, ConvConfig::K_Warp_Tile>,
34-
ConvConfig::PermuteA,
35-
ConvConfig::PermuteB>;
32+
ck_tile::sequence<ConvConfig::M_Warp_Tile,
33+
ConvConfig::N_Warp_Tile,
34+
ConvConfig::K_Warp_Tile>>;
3635

3736
constexpr ck_tile::index_t VectorSizeA = 4;
3837
constexpr ck_tile::index_t VectorSizeB = 8;
@@ -62,7 +61,7 @@ struct GroupedConvolutionBackwardWeightTwoStageInvoker
6261
typename GroupedConvTraitsType::GroupedConvImplicitGemmTraitsBwdWeight::BsLayout,
6362
typename GroupedConvTraitsType::GroupedConvImplicitGemmTraitsBwdWeight::CLayout,
6463
ConvConfig::TransposeC,
65-
ConvConfig::UseStructuredSparsity,
64+
false,
6665
false, // Persistent,
6766
ConvConfig::NumWaveGroups>;
6867

example/ck_tile/20_grouped_convolution/grouped_convolution_forward.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
#include "grouped_convolution_forward_invoker.hpp"
1515
#include "run_grouped_convolution_fwd_example.inc"
1616

17-
template <template <typename PrecType> typename GemmConfig>
17+
template <template <typename PrecType> typename ConvConfig>
1818
int run_grouped_conv_fwd_example(int argc, char* argv[])
1919
{
2020
using Invoker = GroupedConvolutionForwardInvoker;
@@ -31,14 +31,14 @@ int run_grouped_conv_fwd_example(int argc, char* argv[])
3131
if(data_type == "fp16")
3232
{
3333
return run_grouped_conv_fwd_example_prec_type<Invoker,
34-
GemmConfig<ck_tile::half_t>,
34+
ConvConfig<ck_tile::half_t>,
3535
ck_tile::half_t>(
3636
in_layout, wei_layout, out_layout, argc, argv);
3737
}
3838
else if(data_type == "bf16")
3939
{
4040
return run_grouped_conv_fwd_example_prec_type<Invoker,
41-
GemmConfig<ck_tile::bf16_t>,
41+
ConvConfig<ck_tile::bf16_t>,
4242
ck_tile::bf16_t>(
4343
in_layout, wei_layout, out_layout, argc, argv);
4444
}

example/ck_tile/20_grouped_convolution/grouped_convolution_forward_bias_clamp.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
#include "grouped_convolution_forward_invoker.hpp"
1515
#include "run_grouped_convolution_fwd_bias_clamp_example.inc"
1616

17-
template <template <typename PrecType> typename GemmConfig>
17+
template <template <typename PrecType> typename ConvConfig>
1818
int run_grouped_conv_fwd_bias_clamp_example(int argc, char* argv[])
1919
{
2020
using Invoker = GroupedConvolutionForwardInvoker;
@@ -31,14 +31,14 @@ int run_grouped_conv_fwd_bias_clamp_example(int argc, char* argv[])
3131
if(data_type == "fp16")
3232
{
3333
return run_grouped_conv_fwd_bias_clamp_example_prec_type<Invoker,
34-
GemmConfig<ck_tile::half_t>,
34+
ConvConfig<ck_tile::half_t>,
3535
ck_tile::half_t>(
3636
in_layout, wei_layout, out_layout, argc, argv);
3737
}
3838
else if(data_type == "bf16")
3939
{
4040
return run_grouped_conv_fwd_bias_clamp_example_prec_type<Invoker,
41-
GemmConfig<ck_tile::bf16_t>,
41+
ConvConfig<ck_tile::bf16_t>,
4242
ck_tile::bf16_t>(
4343
in_layout, wei_layout, out_layout, argc, argv);
4444
}

0 commit comments

Comments
 (0)