Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 1 addition & 7 deletions example/ck_tile/20_grouped_convolution/conv_configs.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,7 @@ struct ConvConfigBase
static constexpr bool kPadN = true;
static constexpr bool kPadK = true;

static constexpr bool PermuteA = false;
static constexpr bool PermuteB = false;

static constexpr bool TransposeC = false;
static constexpr bool UseStructuredSparsity = false;
static constexpr bool TransposeC = false;

static constexpr ck_tile::index_t VectorSizeA = 4;
static constexpr ck_tile::index_t VectorSizeB = 8;
Expand All @@ -34,8 +30,6 @@ struct ConvConfigBase
static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Intrawave;
static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::COMPUTE_V3;
static constexpr ck_tile::index_t NumWaveGroups = 1;
static constexpr bool Preshuffle = false;
static constexpr bool TiledMMAPermuteN = false;

static constexpr ck_tile::index_t NumGroupsToMerge = 1;
};
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
#include "grouped_convolution_backward_data_invoker.hpp"
#include "run_grouped_convolution_bwd_data_example.inc"

template <template <typename PrecType> typename GemmConfig>
template <template <typename PrecType> typename ConvConfig>
int run_grouped_conv_bwd_data_example(int argc, char* argv[])
{
using Invoker = GroupedConvolutionBackwardDataInvoker;
Expand All @@ -31,14 +31,14 @@ int run_grouped_conv_bwd_data_example(int argc, char* argv[])
if(data_type == "fp16")
{
return run_grouped_conv_bwd_data_example_prec_type<Invoker,
GemmConfig<ck_tile::half_t>,
ConvConfig<ck_tile::half_t>,
ck_tile::half_t>(
in_layout, wei_layout, out_layout, argc, argv);
}
else if(data_type == "bf16")
{
return run_grouped_conv_bwd_data_example_prec_type<Invoker,
GemmConfig<ck_tile::bf16_t>,
ConvConfig<ck_tile::bf16_t>,
ck_tile::bf16_t>(
in_layout, wei_layout, out_layout, argc, argv);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ struct GroupedConvolutionBackwardDataInvoker
{

template <ck_tile::index_t NDimSpatial,
typename GemmConfig,
typename ConvConfig,
typename InDataType,
typename WeiDataType,
typename AccDataType,
Expand All @@ -26,12 +26,11 @@ struct GroupedConvolutionBackwardDataInvoker

// Implicit GEMM Traits
using GemmShape = ck_tile::TileGemmShape<
ck_tile::sequence<GemmConfig::M_Tile, GemmConfig::N_Tile, GemmConfig::K_Tile>,
ck_tile::sequence<GemmConfig::M_Warp, GemmConfig::N_Warp, GemmConfig::K_Warp>,
ck_tile::
sequence<GemmConfig::M_Warp_Tile, GemmConfig::N_Warp_Tile, GemmConfig::K_Warp_Tile>,
GemmConfig::PermuteA,
GemmConfig::PermuteB>;
ck_tile::sequence<ConvConfig::M_Tile, ConvConfig::N_Tile, ConvConfig::K_Tile>,
ck_tile::sequence<ConvConfig::M_Warp, ConvConfig::N_Warp, ConvConfig::K_Warp>,
ck_tile::sequence<ConvConfig::M_Warp_Tile,
ConvConfig::N_Warp_Tile,
ConvConfig::K_Warp_Tile>>;

constexpr ck_tile::index_t VectorSizeA = 8;
constexpr ck_tile::index_t VectorSizeB = 8;
Expand All @@ -40,8 +39,8 @@ struct GroupedConvolutionBackwardDataInvoker
constexpr auto ConvSpec = ck_tile::ConvolutionSpecialization::Default;
using TilePartitioner =
ck_tile::GemmSpatiallyLocalTilePartitioner<GemmShape,
GemmConfig::TileParitionerGroupNum,
GemmConfig::TileParitionerM01>;
ConvConfig::TileParitionerGroupNum,
ConvConfig::TileParitionerM01>;
using GroupedConvTraitsType = ck_tile::GroupedConvTraits<NDimSpatial,
ConvSpec,
InLayout,
Expand All @@ -53,17 +52,17 @@ struct GroupedConvolutionBackwardDataInvoker
VectorSizeC>;

using GemmUniversalTraits = ck_tile::TileGemmUniversalTraits<
GemmConfig::kPadM,
GemmConfig::kPadN,
GemmConfig::kPadK,
GemmConfig::DoubleSmemBuffer,
ConvConfig::kPadM,
ConvConfig::kPadN,
ConvConfig::kPadK,
ConvConfig::DoubleSmemBuffer,
typename GroupedConvTraitsType::GroupedConvImplicitGemmTraitsBwdData::AsLayout,
typename GroupedConvTraitsType::GroupedConvImplicitGemmTraitsBwdData::BsLayout,
typename GroupedConvTraitsType::GroupedConvImplicitGemmTraitsBwdData::CLayout,
GemmConfig::TransposeC,
GemmConfig::UseStructuredSparsity,
ConvConfig::TransposeC,
false,
false, // Persistent,
GemmConfig::NumWaveGroups>;
ConvConfig::NumWaveGroups>;

using GemmPipelineProblem = ck_tile::GemmPipelineProblem<
OutDataType,
Expand All @@ -79,16 +78,16 @@ struct GroupedConvolutionBackwardDataInvoker
VectorSizeB>;

using BaseGemmPipeline = typename PipelineTypeTraits<
GemmConfig::Pipeline>::template UniversalGemmPipeline<GemmPipelineProblem>;
ConvConfig::Pipeline>::template UniversalGemmPipeline<GemmPipelineProblem>;

const ck_tile::index_t gemm_k =
args.K_ * std::accumulate(args.filter_spatial_lengths_.begin(),
args.filter_spatial_lengths_.end(),
1,
std::multiplies<ck_tile::index_t>());

const ck_tile::index_t k_grain = args.k_batch * GemmConfig::K_Tile;
const ck_tile::index_t K_split = (gemm_k + k_grain - 1) / k_grain * GemmConfig::K_Tile;
const ck_tile::index_t k_grain = args.k_batch * ConvConfig::K_Tile;
const ck_tile::index_t K_split = (gemm_k + k_grain - 1) / k_grain * ConvConfig::K_Tile;
const ck_tile::index_t num_loop = TilePartitioner::GetLoopNum(K_split);
const bool has_hot_loop = BaseGemmPipeline::BlockHasHotloop(num_loop);
const ck_tile::TailNumber tail_num = BaseGemmPipeline::GetBlockLoopTailNum(num_loop);
Expand All @@ -98,7 +97,7 @@ struct GroupedConvolutionBackwardDataInvoker
[&](const auto has_hot_loop_, const auto tail_number_, const auto memory_operation_) {
constexpr bool has_hot_loop_v = has_hot_loop_.value;
constexpr auto tail_number_v = tail_number_.value;
constexpr auto scheduler = GemmConfig::Scheduler;
constexpr auto scheduler = ConvConfig::Scheduler;
constexpr auto memory_operation = memory_operation_.value;

using UniversalGemmProblem =
Expand All @@ -118,7 +117,7 @@ struct GroupedConvolutionBackwardDataInvoker
VectorSizeB>;

using GemmPipeline = typename PipelineTypeTraits<
GemmConfig::Pipeline>::template GemmPipeline<UniversalGemmProblem>;
ConvConfig::Pipeline>::template GemmPipeline<UniversalGemmProblem>;

using ConvEpilogue = ck_tile::CShuffleEpilogue<ck_tile::CShuffleEpilogueProblem<
OutDataType,
Expand All @@ -131,12 +130,12 @@ struct GroupedConvolutionBackwardDataInvoker
CDEElementWise,
TilePartitioner::MPerBlock,
TilePartitioner::NPerBlock,
GemmConfig::M_Warp,
GemmConfig::N_Warp,
GemmConfig::M_Warp_Tile,
GemmConfig::N_Warp_Tile,
GemmConfig::K_Warp_Tile,
GemmConfig::TransposeC,
ConvConfig::M_Warp,
ConvConfig::N_Warp,
ConvConfig::M_Warp_Tile,
ConvConfig::N_Warp_Tile,
ConvConfig::K_Warp_Tile,
ConvConfig::TransposeC,
memory_operation,
1,
true,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,10 +27,9 @@ struct GroupedConvolutionBackwardWeightInvoker
using GemmShape = ck_tile::TileGemmShape<
ck_tile::sequence<ConvConfig::M_Tile, ConvConfig::N_Tile, ConvConfig::K_Tile>,
ck_tile::sequence<ConvConfig::M_Warp, ConvConfig::N_Warp, ConvConfig::K_Warp>,
ck_tile::
sequence<ConvConfig::M_Warp_Tile, ConvConfig::N_Warp_Tile, ConvConfig::K_Warp_Tile>,
ConvConfig::PermuteA,
ConvConfig::PermuteB>;
ck_tile::sequence<ConvConfig::M_Warp_Tile,
ConvConfig::N_Warp_Tile,
ConvConfig::K_Warp_Tile>>;

constexpr ck_tile::index_t VectorSizeA = ConvConfig::VectorSizeA;
constexpr ck_tile::index_t VectorSizeB = ConvConfig::VectorSizeB;
Expand Down Expand Up @@ -61,7 +60,7 @@ struct GroupedConvolutionBackwardWeightInvoker
typename GroupedConvTraitsType::GroupedConvImplicitGemmTraitsBwdWeight::BsLayout,
typename GroupedConvTraitsType::GroupedConvImplicitGemmTraitsBwdWeight::CLayout,
ConvConfig::TransposeC,
ConvConfig::UseStructuredSparsity,
false,
false, // Persistent,
ConvConfig::NumWaveGroups>;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,10 +29,9 @@ struct GroupedConvolutionBackwardWeightTwoStageInvoker
using GemmShape = ck_tile::TileGemmShape<
ck_tile::sequence<ConvConfig::M_Tile, ConvConfig::N_Tile, ConvConfig::K_Tile>,
ck_tile::sequence<ConvConfig::M_Warp, ConvConfig::N_Warp, ConvConfig::K_Warp>,
ck_tile::
sequence<ConvConfig::M_Warp_Tile, ConvConfig::N_Warp_Tile, ConvConfig::K_Warp_Tile>,
ConvConfig::PermuteA,
ConvConfig::PermuteB>;
ck_tile::sequence<ConvConfig::M_Warp_Tile,
ConvConfig::N_Warp_Tile,
ConvConfig::K_Warp_Tile>>;

constexpr ck_tile::index_t VectorSizeA = 4;
constexpr ck_tile::index_t VectorSizeB = 8;
Expand Down Expand Up @@ -62,7 +61,7 @@ struct GroupedConvolutionBackwardWeightTwoStageInvoker
typename GroupedConvTraitsType::GroupedConvImplicitGemmTraitsBwdWeight::BsLayout,
typename GroupedConvTraitsType::GroupedConvImplicitGemmTraitsBwdWeight::CLayout,
ConvConfig::TransposeC,
ConvConfig::UseStructuredSparsity,
false,
false, // Persistent,
ConvConfig::NumWaveGroups>;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
#include "grouped_convolution_forward_invoker.hpp"
#include "run_grouped_convolution_fwd_example.inc"

template <template <typename PrecType> typename GemmConfig>
template <template <typename PrecType> typename ConvConfig>
int run_grouped_conv_fwd_example(int argc, char* argv[])
{
using Invoker = GroupedConvolutionForwardInvoker;
Expand All @@ -31,14 +31,14 @@ int run_grouped_conv_fwd_example(int argc, char* argv[])
if(data_type == "fp16")
{
return run_grouped_conv_fwd_example_prec_type<Invoker,
GemmConfig<ck_tile::half_t>,
ConvConfig<ck_tile::half_t>,
ck_tile::half_t>(
in_layout, wei_layout, out_layout, argc, argv);
}
else if(data_type == "bf16")
{
return run_grouped_conv_fwd_example_prec_type<Invoker,
GemmConfig<ck_tile::bf16_t>,
ConvConfig<ck_tile::bf16_t>,
ck_tile::bf16_t>(
in_layout, wei_layout, out_layout, argc, argv);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
#include "grouped_convolution_forward_invoker.hpp"
#include "run_grouped_convolution_fwd_bias_clamp_example.inc"

template <template <typename PrecType> typename GemmConfig>
template <template <typename PrecType> typename ConvConfig>
int run_grouped_conv_fwd_bias_clamp_example(int argc, char* argv[])
{
using Invoker = GroupedConvolutionForwardInvoker;
Expand All @@ -31,14 +31,14 @@ int run_grouped_conv_fwd_bias_clamp_example(int argc, char* argv[])
if(data_type == "fp16")
{
return run_grouped_conv_fwd_bias_clamp_example_prec_type<Invoker,
GemmConfig<ck_tile::half_t>,
ConvConfig<ck_tile::half_t>,
ck_tile::half_t>(
in_layout, wei_layout, out_layout, argc, argv);
}
else if(data_type == "bf16")
{
return run_grouped_conv_fwd_bias_clamp_example_prec_type<Invoker,
GemmConfig<ck_tile::bf16_t>,
ConvConfig<ck_tile::bf16_t>,
ck_tile::bf16_t>(
in_layout, wei_layout, out_layout, argc, argv);
}
Expand Down
Loading