Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add knobs control inner dim unroll and outer dim unroll in pointwise scheduler #3275

Merged
merged 27 commits into from
Nov 2, 2024
Merged
Show file tree
Hide file tree
Changes from 21 commits
Commits
Show all changes
27 commits
Select commit Hold shift + click to select a range
12568b6
unroll the outer dim
liqiangxl Oct 17, 2024
09ba0b6
unroll the outer dim
liqiangxl Oct 17, 2024
c86a0b2
Merge branch 'llu/unroll_outer_dim' of https://github.com/nvidia/fuse…
liqiangxl Oct 18, 2024
f5b349f
comment
liqiangxl Oct 19, 2024
23efc80
enable unroll
liqiangxl Oct 19, 2024
67450af
adjust bdimx for divisible split
liqiangxl Oct 19, 2024
35af092
test heurs
liqiangxl Oct 20, 2024
089dd85
unroll inner and outer
liqiangxl Oct 22, 2024
e2221b5
merge main
liqiangxl Oct 25, 2024
00571b4
wip
liqiangxl Oct 25, 2024
377b7fc
tests
liqiangxl Oct 25, 2024
12ad2e6
clean
liqiangxl Oct 25, 2024
94680b6
python
liqiangxl Oct 25, 2024
7e04577
Merge branch 'main' into llu/ps_unroll_inner_outer
liqiangxl Oct 25, 2024
c5b0365
fix pos
liqiangxl Oct 25, 2024
1307ba8
merge
liqiangxl Oct 26, 2024
4ea639b
Merge branch 'main' into llu/ps_unroll_inner_outer
liqiangxl Oct 27, 2024
6c22a3b
clean
liqiangxl Oct 27, 2024
b23cb41
split even outer unroll factor == 1, should drop this commit, test co…
liqiangxl Oct 28, 2024
7e0cbff
Revert "split even outer unroll factor == 1, should drop this commit,…
liqiangxl Oct 29, 2024
50ba432
Merge branch 'main' into llu/ps_unroll_inner_outer
liqiangxl Oct 29, 2024
cb60e11
set unroll factor based on 1d or 2d scheduler
liqiangxl Oct 30, 2024
8616bc9
add comment
liqiangxl Oct 30, 2024
416bc31
Merge branch 'main' into llu/ps_unroll_inner_outer
liqiangxl Oct 31, 2024
63a38f1
Merge branch 'main' into llu/ps_unroll_inner_outer
liqiangxl Nov 1, 2024
5293c7e
Merge branch 'main' into llu/ps_unroll_inner_outer
liqiangxl Nov 1, 2024
0710686
Merge branch 'main' into llu/ps_unroll_inner_outer
liqiangxl Nov 2, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 5 additions & 2 deletions benchmarks/cpp/utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -97,8 +97,11 @@ std::string toString(const PointwiseParams* pparams) {
if (pparams->vectorization_factor > 1) {
ss << "Vectorize, Factor: " << pparams->vectorization_factor << "\n";
}
if (pparams->unroll_factor > 1) {
ss << "Unroll, Factor: " << pparams->unroll_factor << "\n";
if (pparams->unroll_factor_outer > 1) {
ss << "Outer Unroll, Factor: " << pparams->unroll_factor_outer << "\n";
}
if (pparams->unroll_factor_inner > 1) {
ss << "Inner Unroll, Factor: " << pparams->unroll_factor_inner << "\n";
}
return ss.str();
}
Expand Down
3 changes: 2 additions & 1 deletion csrc/python_frontend/python_bindings.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -640,7 +640,8 @@ void defineHeuristicParamBindings(py::module& nvfuser) {
.PARAM(PointwiseParams, split_grid_y_dim)
.PARAM(PointwiseParams, flip_grid_binding)
.PARAM(PointwiseParams, vectorization_factor)
.PARAM(PointwiseParams, unroll_factor);
.PARAM(PointwiseParams, unroll_factor_inner)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

👍🏼

.PARAM(PointwiseParams, unroll_factor_outer);

// Matmul scheduler parameters
INITHEURISTICPARAMS(MatmulParams)
Expand Down
74 changes: 51 additions & 23 deletions csrc/scheduler/pointwise.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -369,8 +369,10 @@ std::unique_ptr<PointwiseParams> getPointwiseHeuristics(
// not used. should allow to use both unroll and vectorization together in
// heuristics tuning.
if (params->vectorization_factor == 1) {
params->unroll_factor = scheduler_utils::safeDiv(
auto total_unroll = scheduler_utils::safeDiv(
max_vect_unroll_factor, params->vectorization_factor);
params->unroll_factor_inner = total_unroll;
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

IIUC, this PR shouldn't impose any functional changes. So I would expect all old use of params->unroll_factor to be replaced with params->unroll_factor_inner.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes. here all the unroll factors go to unroll_factor_inner through params->unroll_factor_inner = total_unroll;

params->unroll_factor_outer = 1L;
}

NVF_ERROR(right_elem_count > 0 || break_point == 0);
Expand All @@ -394,7 +396,10 @@ std::unique_ptr<PointwiseParams> getPointwiseHeuristics(
<< "num_elems: " << n_elems << "\n"
<< "elem_counts: " << elem_counts << "\n"
<< "max_input_dtype_size: " << max_input_dtype_size << "\n"
<< "unroll_factor: " << params->unroll_factor << std::endl
<< "unroll_factor_inner: " << params->unroll_factor_inner
<< std::endl
<< "unroll_factor_outer: " << params->unroll_factor_outer
<< std::endl
<< "vectorize_factor: " << params->vectorization_factor << std::endl
<< "\n"
<< "logical_reorder_map: ";
Expand Down Expand Up @@ -677,7 +682,9 @@ void schedulePointwise(Fusion* fusion, const PointwiseParams* pparams) {
reference_tv->reorder({{lhs_i, 0}, {-1, 1}});

// vectorization without unroll
if (pparams->unroll_factor == 1 && pparams->vectorization_factor > 1) {
if (pparams->unroll_factor_outer == 1 &&
pparams->unroll_factor_inner == 1 &&
pparams->vectorization_factor > 1) {
reference_tv->split(1, pparams->vectorization_factor);
reference_tv->split(1, NamedScalar::getParallelDim(ParallelType::TIDx));
reference_tv->split(0, 1);
Expand All @@ -700,41 +707,58 @@ void schedulePointwise(Fusion* fusion, const PointwiseParams* pparams) {
reference_tv->split(1, NamedScalar::getParallelDim(ParallelType::TIDx));
// [outer | i-remainder, TIDx, Vect]

reference_tv->split(0, pparams->unroll_factor);
// [o-remainder, Unroll| i-remainder, TIDx, Vect]
if (pparams->unroll_factor_inner > 1) {
reference_tv->split(1, pparams->unroll_factor_inner);
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we are splitting on dimension 1? which is the TIDx here right?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is for 2D scheduler, start with [outer dim, inner dim], so here dimension 1 is i-remainder in [0-outer | 1-i-remainder, 2-TIDx, 3-Vect]. i-remainder means what is left after splitting out other dims, e.g. Vect, TIDx

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So this is a behavior change then.

If we look at the above commented code change, we are doing

-      reference_tv->split(0, pparams->unroll_factor);
-      // [o-remainder, Unroll| i-remainder, TIDx, Vect]
+      if (pparams->unroll_factor_inner > 1) {
+       reference_tv->split(1, pparams->unroll_factor_inner);

Which means the old behavior (outer unroll) is being updated to a default inner unroll instead?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good point. Should assign unroll to inner dim only when the scheduler is 1D, for 2D should assign to outer dim.

    // for 1D scheduler, unroll the inner dimension
    // since there is no outer dimension.
    if (break_point == 0) {
      params->unroll_factor_inner = total_unroll;
      params->unroll_factor_outer = 1L;
    } else {
      // for 2D scheduler, unroll the outer dimension
      // to prioritize reuse across different rows, will
      // be revised in heuristics tuning, e.g. unroll different
      // dims based on the broadcast dimension.
      params->unroll_factor_inner = 1L;
      params->unroll_factor_outer = total_unroll;
    }

}
// [outer| i-remainder, i-Unroll, TIDx, Vect]

if (pparams->unroll_factor_outer > 1) {
reference_tv->split(0, pparams->unroll_factor_outer);
}
// [o-remainder, o-Unroll| i-remainder, i-Unroll, TIDx, Vect]
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm a bit lost about the notation here. What's o-Unroll | i-remainder?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

o represents outer dim and i represents inner dim. | separates inner dim and outer dim. So here o-Unroll represents outer unroll and i-remainder means what is left in the inner dim after splitting out other domains, e.g. Vect, TIDx

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I used o-Unroll and i-Unroll to distinguish between unroll in outer dim and inner dim.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ah, sorry I was totally not getting | part here. Now it reads clear to me.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

add some comments for clarity.

      // Here and in the following comments:
      // prefix [i] represents inner dimension
      // prefix [o] represents outer dimension
      // [|] separates the outer and inner dimensions


reference_tv->split(0, 1);
// [o-remainder, Unswitch, Unroll | i-remainder, TIDx, Vect]
// [o-remainder, Unswitch, o-Unroll | i-remainder, i-Unroll, TIDx, Vect]

reference_tv->reorder({{3, 1}});
// [o-remainder, i-remainder, Unswitch, Unroll, TIDx, Vect]
int i_remainder_pos = pparams->unroll_factor_outer > 1 ? 3 : 2;
reference_tv->reorder({{i_remainder_pos, 1}});
// [o-remainder, i-remainder, Unswitch, o-Unroll, i-Unroll, TIDx, Vect]

reference_tv->axis(2)->parallelize(ParallelType::Unswitch);
// Here we do not set axis(3)->parallelize(Unroll) because we do not want
// it to be propagated. We manually unroll by splitting the inline
// propagation process into two steps:
// step 1: inline at the unswitch position for cached inputs and outputs
// step 2: inline at the inner most dim for the rest of the graph
reference_tv->axis(4)->parallelize(ParallelType::TIDx);
int tidx_pos = 3;
if (pparams->unroll_factor_inner > 1) {
tidx_pos++;
}
if (pparams->unroll_factor_outer > 1) {
tidx_pos++;
}
reference_tv->axis(tidx_pos)->parallelize(ParallelType::TIDx);
if (pparams->vectorization_factor > 1) {
vectorize_id = reference_tv->axis(5);
// can't use {-1}, there may be deviceId
vectorize_id = reference_tv->axis(tidx_pos + 1);
}
// [o-remainder, i-remainder, Unswitch, Unroll, TIDx, Vect]
// [o-remainder, i-remainder, Unswitch, o-Unroll, i-Unroll, TIDx, Vect]
}

// Move out of the way to furthest left point
reference_tv->reorder({{1, 0}});
// [i-remainder, o-remainder, Unswitch, Unroll, TIDx, Vect]
// [i-remainder, o-remainder, Unswitch, o-Unroll, i-Unroll, TIDx, Vect]
if (pparams->split_block) {
reference_tv->split(1, NamedScalar::getParallelDim(ParallelType::TIDy));
// [i-remainder, o-remainder, TIDy, Unswitch, Unroll, TIDx, Vect]
// [i-remainder, o-remainder, TIDy, Unswitch, o-Unroll, i-Unroll, TIDx,
// Vect]
if (pparams->flip_grid_binding) {
// [BIDy | BIDx, TIDy | Unswitch, Unroll, TIDx, Vect]
// [BIDy | BIDx, TIDy | Unswitch, o-Unroll, i-Unroll, TIDx, Vect]
reference_tv->axis(1)->parallelize(ParallelType::BIDx);
reference_tv->axis(2)->parallelize(ParallelType::TIDy);
if (pparams->split_grid_y_dim) {
// [i-remainder, BIDy{65535} | BIDx, TIDy | Unswitch, Unroll, TIDx,
// Vect]
// [i-remainder, BIDy{65535} | BIDx, TIDy | Unswitch, o-Unroll,
// i-Unroll, TIDx, Vect]
reference_tv->split(0, 65535);
reference_tv->axis(1)->parallelize(ParallelType::BIDy);
unswitch_pos = 5;
Expand All @@ -743,12 +767,12 @@ void schedulePointwise(Fusion* fusion, const PointwiseParams* pparams) {
unswitch_pos = 4;
}
} else {
// [BIDx | BIDy TIDy | Unswitch, Unroll, TIDx, Vect]
// [BIDx | BIDy TIDy | Unswitch, o-Unroll, i-Unroll, TIDx, Vect]
reference_tv->axis(0)->parallelize(ParallelType::BIDx);
reference_tv->axis(2)->parallelize(ParallelType::TIDy);
if (pparams->split_grid_y_dim) {
// [BIDx | i-remainder, BIDy{65535}, TIDy | Unswitch, Unroll, TIDx,
// Vect]
// [BIDx | i-remainder, BIDy{65535}, TIDy | Unswitch, o-Unroll,
// i-Unroll, TIDx, Vect]
reference_tv->split(1, 65535);
reference_tv->axis(2)->parallelize(ParallelType::BIDy);
unswitch_pos = 5;
Expand Down Expand Up @@ -796,7 +820,8 @@ void schedulePointwise(Fusion* fusion, const PointwiseParams* pparams) {
// unmerged...]
reference_tv->reorder({{-1, 0}});

if (pparams->unroll_factor == 1 && pparams->vectorization_factor > 1) {
if (pparams->unroll_factor_inner == 1 &&
pparams->vectorization_factor > 1) {
// Vectorize
reference_tv->split(0, pparams->vectorization_factor);
// Unswitch
Expand All @@ -822,7 +847,9 @@ void schedulePointwise(Fusion* fusion, const PointwiseParams* pparams) {
// Threads
reference_tv->split(0, kThreadX);
// Unroll
reference_tv->split(0, pparams->unroll_factor);
if (pparams->unroll_factor_inner > 1) {
reference_tv->split(0, pparams->unroll_factor_inner);
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

qq: we are not using unroll_factor_outer in this branch, is that expected?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, this else branch is for 1D scheduler, all IDs are merged into 1 domain, there is no outer dim.

}
// Unswitch
reference_tv->split(0, 1);

Expand All @@ -834,9 +861,10 @@ void schedulePointwise(Fusion* fusion, const PointwiseParams* pparams) {
// propagation process into two steps:
// step 1: inline at the unswitch position for cached inputs and outputs
// step 2: inline at the inner most dim for the rest of the graph
reference_tv->axis(3)->parallelize(ParallelType::TIDx);
int tidx_pos = pparams->unroll_factor_inner > 1 ? 3 : 2;
reference_tv->axis(tidx_pos)->parallelize(ParallelType::TIDx);
if (pparams->vectorization_factor > 1) {
vectorize_id = reference_tv->axis(4);
vectorize_id = reference_tv->axis(tidx_pos + 1);
}
}
unswitch_pos = 2;
Expand Down
23 changes: 14 additions & 9 deletions csrc/scheduler/pointwise_heuristic.h
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,13 @@ class PointwiseParams : public HeuristicParams {
// Unroll on top of vectorization
// In the 2D scheduler, unroll the outer dimension to reuse loaded data across
// rows, reducing loaded bytes by the unroll factor.
int64_t unroll_factor = 1;
// Always equals 1 for 1D scheduler.
int64_t unroll_factor_outer = 1;

// In the 2D scheduler, unroll the inner dimension to reuse loaded data across
// cols, reducing loaded bytes by the unroll factor.
// Also used in 1D scheduler.
int64_t unroll_factor_inner = 1;

using HeuristicParams::HeuristicParams;

Expand All @@ -60,7 +66,8 @@ class PointwiseParams : public HeuristicParams {
other->break_point == break_point &&
other->split_block == split_block &&
other->split_grid_y_dim == split_grid_y_dim &&
other->unroll_factor == unroll_factor &&
other->unroll_factor_outer == unroll_factor_outer &&
other->unroll_factor_inner == unroll_factor_inner &&
other->flip_grid_binding == flip_grid_binding;
return attr_equal;
}
Expand All @@ -81,12 +88,9 @@ class PointwiseParams : public HeuristicParams {
ss << " Split y grid dim\n";
}
}
if (vectorization_factor > 1) {
ss << "Vectorize, Factor: " << vectorization_factor << "\n";
}
if (unroll_factor > 1) {
ss << "Unroll, Factor: " << unroll_factor << "\n";
}
ss << "vectorization_factor: " << vectorization_factor << "\n";
ss << "unroll_factor_outer: " << unroll_factor_outer << "\n";
ss << "unroll_factor_inner: " << unroll_factor_inner << "\n";
if (flip_grid_binding) {
ss << "Flip BIDx/BIDy bindings\n";
}
Expand All @@ -100,7 +104,8 @@ class PointwiseParams : public HeuristicParams {
static_cast<size_t>(break_point) << 4 ^
static_cast<size_t>(split_block) << 5 ^
static_cast<size_t>(split_grid_y_dim) << 6 ^
static_cast<size_t>(unroll_factor) << 9 ^
static_cast<size_t>(unroll_factor_outer) << 7 ^
static_cast<size_t>(unroll_factor_inner) << 9 ^
static_cast<size_t>(flip_grid_binding) << 10;
return attr_hash;
}
Expand Down
37 changes: 30 additions & 7 deletions tests/cpp/test_pointwise.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -666,7 +666,9 @@ TEST_F(PointwiseTest, VectorizeWithExpandedBroadcast) {
EXPECT_GT(getVecSizeForPointwise(fec), 1);
}

TEST_F(PointwiseTest, UnrollOnTopOfVectorize) {
using VectUnrollFactors = std::tuple<int64_t, int64_t, int64_t>;
using PointwiseParamsTest = NVFuserFixtureParamTest<VectUnrollFactors>;
TEST_P(PointwiseParamsTest, UnrollOnTopOfVectorize) {
auto fusion = std::make_unique<Fusion>();
FusionGuard fg(fusion.get());

Expand All @@ -685,25 +687,46 @@ TEST_F(PointwiseTest, UnrollOnTopOfVectorize) {
auto t1 = at::randn({dim1}, options);
std::vector<c10::IValue> runtime_inputs{t0, t1};

// generate heuristics
// Generate heuristics
SchedulerRuntimeInfo runtime_info(fusion.get(), runtime_inputs);
auto scheduler_instance =
SchedulerEntry::makeSchedulerInstance(SchedulerType::PointWise);
auto heuristic_params =
scheduler_instance->computeHeuristics(fusion.get(), runtime_info);
auto pparams = heuristic_params->as<PointwiseParams>();

// modify heuristics to enforce unroll on top of vectorization
pparams->vectorization_factor = 4;
pparams->unroll_factor = 2;
// Modify heuristics to enforce unroll on top of vectorization

// schedule, compile, run, validate
// Set unroll factors from test parameters
auto [vect_factor, unroll_inner, unroll_outer] = GetParam();
pparams->unroll_factor_inner = unroll_inner;
pparams->unroll_factor_outer = unroll_outer;
pparams->vectorization_factor = vect_factor;

// Schedule, compile, run, validate
scheduler_instance->schedule(fusion.get(), pparams);
FusionExecutor fe;
fe.compileFusion(fusion.get(), runtime_inputs, pparams->lparams);
auto cg_outputs = fe.runFusion(runtime_inputs, pparams->lparams);
const auto& lparams = fe.lastLaunchParams();
ASSERT_EQ(lparams.gdimy(), dim0 / pparams->unroll_factor);
ASSERT_EQ(lparams.gdimy(), dim0 / unroll_outer);
ASSERT_EQ(
lparams.gdimx(), dim1 / vect_factor / lparams.bdimx() / unroll_inner);
testValidate(fusion.get(), cg_outputs, runtime_inputs, __LINE__, __FILE__);
}
INSTANTIATE_TEST_SUITE_P(
,
PointwiseParamsTest,
::testing::Combine(
testing::Values(1, 4), // vectorization factors
testing::Values(1, 2), // inner unroll factors
testing::Values(1, 2) // outer unroll factors
),
[](const testing::TestParamInfo<VectUnrollFactors>& info) -> std::string {
std::stringstream ss;
ss << "vect_" << std::get<0>(info.param);
ss << "_inner_unroll_" << std::get<1>(info.param);
ss << "_outer_unroll_" << std::get<2>(info.param);
return sanitizeTestName(ss.str());
});
} // namespace nvfuser
Loading