Skip to content

Commit b0282e7

Browse files
committed
Add persistent-dp option to streamk example
1 parent bcaf825 commit b0282e7

File tree

4 files changed

+31
-10
lines changed

4 files changed

+31
-10
lines changed

example/ck_tile/40_streamk_gemm/README.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ args:
2323
-b_layout tensor B data layout (default: C)
2424
-c_layout tensor C data layout (default: R)
2525
-reduction_strategy strategy for storing results in C tensor. atomic/reduction (default:atomic)
26+
-persistent_dp persistent strategy for data-parallel section. 0. Non-persistent, 1 persistent.")
2627
-stride_a tensor A stride (default:0)
2728
-stride_b tensor B stride (default:0)
2829
-stride_c tensor C stride (default:0)

example/ck_tile/40_streamk_gemm/gemm_utils.hpp

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,6 @@ struct GemmConfigBase
1818

1919
static constexpr bool TransposeC = false;
2020
static constexpr bool UseStructuredSparsity = false;
21-
static constexpr bool Persistent = false;
2221

2322
static constexpr int kBlockPerCu = 1;
2423
static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Intrawave;
@@ -27,7 +26,7 @@ struct GemmConfigBase
2726
static constexpr bool DoubleSmemBuffer = false;
2827
};
2928

30-
template <typename PrecType>
29+
template <typename PrecType, bool Persistent_>
3130
struct GemmConfigMemoryInterwave : public GemmConfigBase
3231
{
3332
static constexpr ck_tile::index_t M_Tile = 256;
@@ -42,7 +41,8 @@ struct GemmConfigMemoryInterwave : public GemmConfigBase
4241
static constexpr ck_tile::index_t N_Warp_Tile = 32;
4342
static constexpr ck_tile::index_t K_Warp_Tile = sizeof(PrecType) == 2 ? 8 : 16;
4443

45-
static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Intrawave;
44+
static constexpr bool Persistent = Persistent_;
45+
static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Intrawave;
4646
};
4747

4848
template <typename ADataType_, typename BDataType_ = ADataType_, typename CDataType_ = ADataType_>
@@ -87,6 +87,9 @@ auto create_args(int argc, char* argv[])
8787
.insert("reduction_strategy",
8888
"atomic",
8989
"strategy for storing results in C tensor - atomic/reduction")
90+
.insert("persistent_dp",
91+
"0",
92+
"0. Non-persistent data-parallel section, 1 Fully persistent kernel.")
9093
.insert("stride_a", "0", "Tensor A stride")
9194
.insert("stride_b", "0", "Tensor B stride")
9295
.insert("stride_c", "0", "Tensor C stride")

example/ck_tile/40_streamk_gemm/run_gemm_example.inc

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -275,8 +275,8 @@ int run_gemm_example_with_layouts(int argc,
275275
<< " B_Type=" << DataTypeTraits<BDataType>::name
276276
<< " C_Type=" << DataTypeTraits<CDataType>::name
277277
<< " reduction_strategy=" << arg_parser.get_str("reduction_strategy") << " "
278-
<< ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, "
279-
<< std::endl;
278+
<< " persistent_dp=" << arg_parser.get_str("persistent_dp") << " " << ave_time
279+
<< " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, " << std::endl;
280280

281281
bool pass = true;
282282

example/ck_tile/40_streamk_gemm/streamk_gemm_basic.cpp

Lines changed: 22 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -169,7 +169,7 @@ int run_gemm_example_prec_type(std::string a_layout, std::string b_layout, int a
169169
return 0;
170170
}
171171

172-
template <template <typename PreType> typename GemmConfig>
172+
template <template <typename PreType, bool Persistent_> typename GemmConfig>
173173
int run_gemm_example(int argc, char* argv[])
174174
{
175175
auto [result, arg_parser] = create_args(argc, argv);
@@ -179,18 +179,35 @@ int run_gemm_example(int argc, char* argv[])
179179
std::string data_type = arg_parser.get_str("prec");
180180
std::string a_layout = arg_parser.get_str("a_layout");
181181
std::string b_layout = arg_parser.get_str("b_layout");
182+
auto persistent_dp = arg_parser.get_bool("persistent_dp");
182183

183184
if(data_type == "bf16")
184185
{
185186
using TypeConfig = StreamKGemmTypeConfig<ck_tile::bf16_t>;
186-
return run_gemm_example_prec_type<GemmConfig<ck_tile::bf16_t>, TypeConfig>(
187-
a_layout, b_layout, argc, argv);
187+
if(persistent_dp)
188+
{
189+
return run_gemm_example_prec_type<GemmConfig<ck_tile::bf16_t, true>, TypeConfig>(
190+
a_layout, b_layout, argc, argv);
191+
}
192+
else
193+
{
194+
return run_gemm_example_prec_type<GemmConfig<ck_tile::bf16_t, false>, TypeConfig>(
195+
a_layout, b_layout, argc, argv);
196+
}
188197
}
189198
else if(data_type == "fp16")
190199
{
191200
using TypeConfig = StreamKGemmTypeConfig<ck_tile::half_t>;
192-
return run_gemm_example_prec_type<GemmConfig<ck_tile::half_t>, TypeConfig>(
193-
a_layout, b_layout, argc, argv);
201+
if(persistent_dp)
202+
{
203+
return run_gemm_example_prec_type<GemmConfig<ck_tile::half_t, true>, TypeConfig>(
204+
a_layout, b_layout, argc, argv);
205+
}
206+
else
207+
{
208+
return run_gemm_example_prec_type<GemmConfig<ck_tile::half_t, false>, TypeConfig>(
209+
a_layout, b_layout, argc, argv);
210+
}
194211
}
195212
else
196213
{

0 commit comments

Comments
 (0)