@@ -8,7 +8,7 @@ struct GroupedConvolutionBackwardDataInvoker
88{
99
1010 template <ck_tile::index_t NDimSpatial,
11- typename GemmConfig ,
11+ typename ConvConfig ,
1212 typename InDataType,
1313 typename WeiDataType,
1414 typename AccDataType,
@@ -26,12 +26,11 @@ struct GroupedConvolutionBackwardDataInvoker
2626
2727 // Implicit GEMM Traits
2828 using GemmShape = ck_tile::TileGemmShape<
29- ck_tile::sequence<GemmConfig::M_Tile, GemmConfig::N_Tile, GemmConfig::K_Tile>,
30- ck_tile::sequence<GemmConfig::M_Warp, GemmConfig::N_Warp, GemmConfig::K_Warp>,
31- ck_tile::
32- sequence<GemmConfig::M_Warp_Tile, GemmConfig::N_Warp_Tile, GemmConfig::K_Warp_Tile>,
33- GemmConfig::PermuteA,
34- GemmConfig::PermuteB>;
29+ ck_tile::sequence<ConvConfig::M_Tile, ConvConfig::N_Tile, ConvConfig::K_Tile>,
30+ ck_tile::sequence<ConvConfig::M_Warp, ConvConfig::N_Warp, ConvConfig::K_Warp>,
31+ ck_tile::sequence<ConvConfig::M_Warp_Tile,
32+ ConvConfig::N_Warp_Tile,
33+ ConvConfig::K_Warp_Tile>>;
3534
3635 constexpr ck_tile::index_t VectorSizeA = 8 ;
3736 constexpr ck_tile::index_t VectorSizeB = 8 ;
@@ -40,8 +39,8 @@ struct GroupedConvolutionBackwardDataInvoker
4039 constexpr auto ConvSpec = ck_tile::ConvolutionSpecialization::Default;
4140 using TilePartitioner =
4241 ck_tile::GemmSpatiallyLocalTilePartitioner<GemmShape,
43- GemmConfig ::TileParitionerGroupNum,
44- GemmConfig ::TileParitionerM01>;
42+ ConvConfig ::TileParitionerGroupNum,
43+ ConvConfig ::TileParitionerM01>;
4544 using GroupedConvTraitsType = ck_tile::GroupedConvTraits<NDimSpatial,
4645 ConvSpec,
4746 InLayout,
@@ -53,17 +52,17 @@ struct GroupedConvolutionBackwardDataInvoker
5352 VectorSizeC>;
5453
5554 using GemmUniversalTraits = ck_tile::TileGemmUniversalTraits<
56- GemmConfig ::kPadM ,
57- GemmConfig ::kPadN ,
58- GemmConfig ::kPadK ,
59- GemmConfig ::DoubleSmemBuffer,
55+ ConvConfig ::kPadM ,
56+ ConvConfig ::kPadN ,
57+ ConvConfig ::kPadK ,
58+ ConvConfig ::DoubleSmemBuffer,
6059 typename GroupedConvTraitsType::GroupedConvImplicitGemmTraitsBwdData::AsLayout,
6160 typename GroupedConvTraitsType::GroupedConvImplicitGemmTraitsBwdData::BsLayout,
6261 typename GroupedConvTraitsType::GroupedConvImplicitGemmTraitsBwdData::CLayout,
63- GemmConfig ::TransposeC,
64- GemmConfig::UseStructuredSparsity ,
62+ ConvConfig ::TransposeC,
63+ false ,
6564 false , // Persistent,
66- GemmConfig ::NumWaveGroups>;
65+ ConvConfig ::NumWaveGroups>;
6766
6867 using GemmPipelineProblem = ck_tile::GemmPipelineProblem<
6968 OutDataType,
@@ -79,16 +78,16 @@ struct GroupedConvolutionBackwardDataInvoker
7978 VectorSizeB>;
8079
8180 using BaseGemmPipeline = typename PipelineTypeTraits<
82- GemmConfig ::Pipeline>::template UniversalGemmPipeline<GemmPipelineProblem>;
81+ ConvConfig ::Pipeline>::template UniversalGemmPipeline<GemmPipelineProblem>;
8382
8483 const ck_tile::index_t gemm_k =
8584 args.K_ * std::accumulate (args.filter_spatial_lengths_ .begin (),
8685 args.filter_spatial_lengths_ .end (),
8786 1 ,
8887 std::multiplies<ck_tile::index_t >());
8988
90- const ck_tile::index_t k_grain = args.k_batch * GemmConfig ::K_Tile;
91- const ck_tile::index_t K_split = (gemm_k + k_grain - 1 ) / k_grain * GemmConfig ::K_Tile;
89+ const ck_tile::index_t k_grain = args.k_batch * ConvConfig ::K_Tile;
90+ const ck_tile::index_t K_split = (gemm_k + k_grain - 1 ) / k_grain * ConvConfig ::K_Tile;
9291 const ck_tile::index_t num_loop = TilePartitioner::GetLoopNum (K_split);
9392 const bool has_hot_loop = BaseGemmPipeline::BlockHasHotloop (num_loop);
9493 const ck_tile::TailNumber tail_num = BaseGemmPipeline::GetBlockLoopTailNum (num_loop);
@@ -98,7 +97,7 @@ struct GroupedConvolutionBackwardDataInvoker
9897 [&](const auto has_hot_loop_, const auto tail_number_, const auto memory_operation_) {
9998 constexpr bool has_hot_loop_v = has_hot_loop_.value ;
10099 constexpr auto tail_number_v = tail_number_.value ;
101- constexpr auto scheduler = GemmConfig ::Scheduler;
100+ constexpr auto scheduler = ConvConfig ::Scheduler;
102101 constexpr auto memory_operation = memory_operation_.value ;
103102
104103 using UniversalGemmProblem =
@@ -118,7 +117,7 @@ struct GroupedConvolutionBackwardDataInvoker
118117 VectorSizeB>;
119118
120119 using GemmPipeline = typename PipelineTypeTraits<
121- GemmConfig ::Pipeline>::template GemmPipeline<UniversalGemmProblem>;
120+ ConvConfig ::Pipeline>::template GemmPipeline<UniversalGemmProblem>;
122121
123122 using ConvEpilogue = ck_tile::CShuffleEpilogue<ck_tile::CShuffleEpilogueProblem<
124123 OutDataType,
@@ -131,12 +130,12 @@ struct GroupedConvolutionBackwardDataInvoker
131130 CDEElementWise,
132131 TilePartitioner::MPerBlock,
133132 TilePartitioner::NPerBlock,
134- GemmConfig ::M_Warp,
135- GemmConfig ::N_Warp,
136- GemmConfig ::M_Warp_Tile,
137- GemmConfig ::N_Warp_Tile,
138- GemmConfig ::K_Warp_Tile,
139- GemmConfig ::TransposeC,
133+ ConvConfig ::M_Warp,
134+ ConvConfig ::N_Warp,
135+ ConvConfig ::M_Warp_Tile,
136+ ConvConfig ::N_Warp_Tile,
137+ ConvConfig ::K_Warp_Tile,
138+ ConvConfig ::TransposeC,
140139 memory_operation,
141140 1 ,
142141 true ,
0 commit comments