Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

AIRParallelToHerd/Launch: Convert forall to parallel at the start of the pass #861

Merged
merged 1 commit into from
Jan 17, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
268 changes: 50 additions & 218 deletions mlir/lib/Conversion/ConvertToAIRPass.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -610,122 +610,6 @@ class ScfParToHerdConversion : public OpRewritePattern<scf::ParallelOp> {
int firstDim;
};

/// Rewrites a filtered `scf.forall` into an `air.herd`.
///
/// Only ops contained in `filteredOps` are considered. The innermost (at most
/// two) dimensions of the loop become the herd; for loops of rank greater than
/// two the leading dimensions are first peeled off into an enclosing
/// `scf.parallel`. Values captured from above become herd kernel operands,
/// except `arith.constant` results, which are re-materialized inside the herd.
///
/// The dimensions mapped onto the herd must have constant lower bound, upper
/// bound and step (either static attributes or SSA constants); otherwise the
/// pattern fails to match without touching the IR. Each herd dimension is
/// sized to the loop's trip count, ceildiv(ub - lb, step), and the original
/// induction variable is rebuilt inside the herd as `lb + id * step`, so
/// non-normalized loops keep their iteration values.
class ScfForallToHerdConversion : public OpRewritePattern<scf::ForallOp> {
public:
  using OpRewritePattern<scf::ForallOp>::OpRewritePattern;

  /// \param ctx            context the pattern is registered in.
  /// \param filteredOps    set of ops this pattern is allowed to rewrite.
  /// \param replacementOps receives every `air.herd` created by this pattern.
  /// \param firstDim       index (0 or 1) of the herd dimension that the first
  ///                       mapped loop dimension is assigned to.
  ScfForallToHerdConversion(MLIRContext *ctx,
                            SmallPtrSet<Operation *, 8> &filteredOps,
                            llvm::SmallSet<air::HerdOp, 2> &replacementOps,
                            int firstDim)
      : OpRewritePattern(ctx), filteredOps(filteredOps),
        replacementOps(replacementOps), firstDim(firstDim) {}

  LogicalResult matchAndRewrite(scf::ForallOp parOp,
                                PatternRewriter &rewriter) const override {

    scf::ForallOp op = parOp;

    if (!filteredOps.contains(op))
      return failure();

    auto loc = op.getLoc();

    // A rank-0 loop has no induction variable to map onto a herd id.
    const unsigned rank = op.getRank();
    if (rank == 0)
      return failure();

    // Dimensions [split_idx, rank) are mapped onto the herd; anything before
    // that is peeled into an outer scf.parallel.
    const unsigned split_idx = rank > 2 ? rank - 2 : 0;
    const unsigned herdRank = rank - split_idx;

    // Fetch the mixed bounds once; each accessor builds a fresh vector.
    auto mixedLbs = op.getMixedLowerBound();
    auto mixedUbs = op.getMixedUpperBound();
    auto mixedSteps = op.getMixedStep();

    // Validate the herd dimensions and compute their trip counts BEFORE any
    // IR is mutated, so that a failed match leaves the input untouched.
    SmallVector<int64_t, 2> bounds{1, 1};
    SmallVector<int64_t, 2> herdLbs(herdRank, 0);
    SmallVector<int64_t, 2> herdSteps(herdRank, 1);
    for (unsigned i = 0; i < herdRank; ++i) {
      std::optional<int64_t> lb = getConstantIntValue(mixedLbs[split_idx + i]);
      std::optional<int64_t> ub = getConstantIntValue(mixedUbs[split_idx + i]);
      std::optional<int64_t> step =
          getConstantIntValue(mixedSteps[split_idx + i]);
      // The herd shape must be known at compile time.
      if (!lb || !ub || !step)
        return failure();
      if (*step <= 0 || *ub < *lb)
        return failure();
      herdLbs[i] = *lb;
      herdSteps[i] = *step;
      // Trip count of `for (iv = lb; iv < ub; iv += step)`.
      bounds[i] = (*ub - *lb + *step - 1) / *step;
    }

    if (rank > 2) {
      SmallVector<OpFoldResult> outerLowerBounds(mixedLbs.begin(),
                                                 mixedLbs.begin() + split_idx);
      SmallVector<OpFoldResult> outerUpperBounds(mixedUbs.begin(),
                                                 mixedUbs.begin() + split_idx);
      SmallVector<OpFoldResult> outerSteps(mixedSteps.begin(),
                                           mixedSteps.begin() + split_idx);
      SmallVector<OpFoldResult> innerLowerBounds(mixedLbs.begin() + split_idx,
                                                 mixedLbs.end());
      SmallVector<OpFoldResult> innerUpperBounds(mixedUbs.begin() + split_idx,
                                                 mixedUbs.end());
      SmallVector<OpFoldResult> innerSteps(mixedSteps.begin() + split_idx,
                                           mixedSteps.end());

      // Leading dimensions stay a plain parallel loop around the herd.
      auto outerLoop = rewriter.create<scf::ParallelOp>(
          loc, getValueOrCreateConstantIndexOp(rewriter, loc, outerLowerBounds),
          getValueOrCreateConstantIndexOp(rewriter, loc, outerUpperBounds),
          getValueOrCreateConstantIndexOp(rewriter, loc, outerSteps));
      for (unsigned i = 0, e = split_idx; i < e; ++i)
        op.getInductionVars()[i].replaceAllUsesWith(
            outerLoop.getInductionVars()[i]);

      rewriter.setInsertionPointToStart(outerLoop.getBody());

      // Temporary 2-d forall holding the body; it is converted below and
      // erased at the end.
      auto innerLoop = rewriter.create<scf::ForallOp>(
          loc, innerLowerBounds, innerUpperBounds, innerSteps, ValueRange(),
          std::nullopt);
      for (unsigned i = split_idx, e = rank; i < e; ++i)
        op.getInductionVars()[i].replaceAllUsesWith(
            innerLoop.getInductionVars()[i - split_idx]);

      // Move everything except the terminator into the inner loop.
      auto &body = op.getBody()->getOperations();
      innerLoop.getBody()->getOperations().splice(
          innerLoop.getBody()->begin(), body, body.begin(), --body.end());
      op = innerLoop;
    }

    // Split captured values into constants (cloned into the herd) and
    // everything else (passed as kernel operands).
    SmallVector<Value, 4> args;
    SmallVector<Value, 4> constants;
    llvm::SetVector<Value> region_args;
    getUsedValuesDefinedAbove(op.getRegion(), region_args);
    for (Value v : region_args) {
      if (isa_and_present<arith::ConstantOp>(v.getDefiningOp()))
        constants.push_back(v);
      else
        args.push_back(v);
    }

    int idx0 = firstDim;
    int idx1 = (firstDim + 1) % 2;
    SmallVector<Value, 2> dims{
        rewriter.create<arith::ConstantIndexOp>(loc, bounds[idx0]),
        rewriter.create<arith::ConstantIndexOp>(loc, bounds[idx1])};
    auto herdOp = rewriter.create<air::HerdOp>(op.getLoc(), dims, args);
    auto &bb = herdOp.getBody().front();
    auto ivs = op.getInductionVars();

    // Must run while the body still lives inside `op`.
    propagateLinkWith(op, herdOp);

    // Move the loop body (minus its terminator) into the herd, then emit any
    // helper ops in front of it.
    auto &body = op.getBody()->getOperations();
    bb.getOperations().splice(bb.begin(), body, body.begin(), --body.end());
    rewriter.setInsertionPointToStart(&herdOp.getRegion().front());

    // Rebuild induction variable `ivIdx` from a herd id as `lb + id * step`.
    // Nothing is emitted for an already-normalized dimension.
    auto remapInductionVar = [&](unsigned ivIdx, Value herdId) {
      Value replacement = herdId;
      if (herdSteps[ivIdx] != 1) {
        Value stepCst =
            rewriter.create<arith::ConstantIndexOp>(loc, herdSteps[ivIdx]);
        replacement = rewriter.create<arith::MulIOp>(loc, replacement, stepCst);
      }
      if (herdLbs[ivIdx] != 0) {
        Value lbCst =
            rewriter.create<arith::ConstantIndexOp>(loc, herdLbs[ivIdx]);
        replacement = rewriter.create<arith::AddIOp>(loc, replacement, lbCst);
      }
      ivs[ivIdx].replaceAllUsesWith(replacement);
    };
    remapInductionVar(0, herdOp.getIds()[idx0]);
    if (herdRank == 2)
      remapInductionVar(1, herdOp.getIds()[idx1]);

    replaceAllUsesOfConstsInRegionWithNew(constants, rewriter,
                                          herdOp.getRegion());

    // Redirect remaining captured values to the matching kernel arguments.
    int i = 0;
    auto kernel_args = herdOp.getKernelArguments();
    for (Value v : args)
      replaceAllUsesInRegionWith(v, kernel_args[i++], herdOp.getRegion());

    // `op` differs from `parOp` only when the temporary inner loop was made.
    if (op != parOp)
      rewriter.eraseOp(op);
    rewriter.eraseOp(parOp);
    replacementOps.insert(herdOp);

    return success();
  }

private:
  llvm::SmallPtrSet<Operation *, 8> &filteredOps;
  llvm::SmallSet<air::HerdOp, 2> &replacementOps;
  int firstDim;
};

LogicalResult
getMemrefBackwardSlices(Value &memref, Operation *&memrefAlloc,
SmallVector<Operation *> &backwardSlices) {
Expand Down Expand Up @@ -997,107 +881,37 @@ class ScfParToLaunchConversion : public OpRewritePattern<scf::ParallelOp> {
bool generateSegment;
};

class ScfForallToLaunchConversion : public OpRewritePattern<scf::ForallOp> {
public:
/// Pattern to rewrite scf.forall into scf.parallel after bufferization.
class SCFForAllToParallelOp : public OpRewritePattern<scf::ForallOp> {
using OpRewritePattern<scf::ForallOp>::OpRewritePattern;

ScfForallToLaunchConversion(MLIRContext *ctx,
llvm::SmallSet<Operation *, 8> &filteredOps,
llvm::SmallSet<air::LaunchOp, 2> &replacementOps,
bool generateSegment)
: OpRewritePattern(ctx), filteredOps(filteredOps),
replacementOps(replacementOps), generateSegment(generateSegment){};

LogicalResult matchAndRewrite(scf::ForallOp forOp,
LogicalResult matchAndRewrite(scf::ForallOp forallOp,
PatternRewriter &rewriter) const override {

scf::ForallOp op = forOp;

if (!filteredOps.contains(op))
if (forallOp.getNumResults() != 0) {
return failure();

// if (failed(normalizeScfParallel(op, rewriter)))
// return failure();

auto loc = op.getLoc();

SmallVector<int, 4> bounds(op.getRank(), 1);
for (unsigned int i = 0; i < op.getRank(); i++) {
int64_t lb_int = op.getStaticLowerBound()[i];
int64_t ub_int = op.getStaticUpperBound()[i];
int64_t step_int = op.getStaticStep()[i];

// must start at 0
if (lb_int)
return failure();

// step must divide upper bound evenly
if (ub_int % step_int)
return failure();

ub_int = ub_int / step_int;
bounds[i] = ub_int;
}

SmallVector<Value, 4> args;
SmallVector<Value, 4> constants;
llvm::SetVector<Value> region_args;
getUsedValuesDefinedAbove(op.getRegion(), region_args);
for (Value v : region_args) {
if (isa_and_present<arith::ConstantOp>(v.getDefiningOp()))
constants.push_back(v);
else
args.push_back(v);
}

SmallVector<Value, 4> sizes;
for (auto b : bounds)
sizes.push_back(rewriter.create<arith::ConstantIndexOp>(loc, b));
auto launch = rewriter.create<air::LaunchOp>(op.getLoc(), sizes, args);

rewriter.setInsertionPointToStart(&launch.getRegion().front());

if (generateSegment) {
auto segment = generateEmptySegmentOp(rewriter, op, launch);
replaceAllUsesOfConstsInRegionWithNew(constants, rewriter,
segment.getRegion());
int i = 0;
auto kernel_args = segment.getKernelArguments();
kernel_args = kernel_args.drop_front(
launch.getIds().size() +
launch.getSize().size()); // Launch's induction vars
for (Value v : args)
replaceAllUsesInRegionWith(v, kernel_args[i++], segment.getRegion());
} else {
auto &bb = launch.getBody().front();
auto ivs = op.getInductionVars();

for (int i = 0, e = ivs.size(); i < e; i++) {
ivs[i].replaceAllUsesWith(launch.getIds()[i]);
}

auto &body = op.getBody()->getOperations();
bb.getOperations().splice(bb.begin(), body, body.begin(), --body.end());
replaceAllUsesOfConstsInRegionWithNew(constants, rewriter,
launch.getRegion());
int i = 0;
auto kernel_args = launch.getKernelArguments();
for (Value v : args)
replaceAllUsesInRegionWith(v, kernel_args[i++], launch.getRegion());
}

if (op != forOp)
op.erase();
rewriter.eraseOp(forOp);
replacementOps.insert(launch);

Location loc = forallOp.getLoc();
SmallVector<Value> lowerBounds = getValueOrCreateConstantIndexOp(
rewriter, loc, forallOp.getMixedLowerBound());
SmallVector<Value> upperBounds = getValueOrCreateConstantIndexOp(
rewriter, loc, forallOp.getMixedUpperBound());
SmallVector<Value> step =
getValueOrCreateConstantIndexOp(rewriter, loc, forallOp.getMixedStep());
auto parallelOp = rewriter.create<scf::ParallelOp>(
loc, lowerBounds, upperBounds, step, ValueRange{},
[&](OpBuilder &b, Location bodyLoc, ValueRange ivs,
ValueRange regionArgs) {});
rewriter.inlineRegionBefore(forallOp.getRegion(), parallelOp.getRegion(),
parallelOp.getRegion().begin());
rewriter.eraseBlock(&parallelOp.getRegion().back());
// Fixup the terminator
OpBuilder::InsertionGuard g(rewriter);
rewriter.setInsertionPointToEnd(&parallelOp.getRegion().front());
rewriter.replaceOpWithNewOp<scf::ReduceOp>(
parallelOp.getRegion().front().getTerminator());
rewriter.replaceOp(forallOp, parallelOp);
return success();
}

private:
llvm::SmallSet<Operation *, 8> &filteredOps;
llvm::SmallSet<air::LaunchOp, 2> &replacementOps;
bool generateSegment;
};

struct CopyToDmaPass : public air::impl::CopyToDmaBase<CopyToDmaPass> {
Expand Down Expand Up @@ -1319,6 +1133,19 @@ struct ParallelToHerdPass
LLVM_DEBUG(llvm::outs() << "input\n");
LLVM_DEBUG(module.print(llvm::outs()));

// Preprocessing: convert forall to parallel.
RewritePatternSet preprocPatterns(context);
preprocPatterns.add<SCFForAllToParallelOp>(context);
ConversionTarget preprocTarget(*context);
preprocTarget.addLegalDialect<scf::SCFDialect, arith::ArithDialect>();
preprocTarget.addIllegalOp<scf::ForallOp>();
if (failed(applyPartialConversion(module, preprocTarget,
std::move(preprocPatterns)))) {
signalPassFailure();
}
LLVM_DEBUG(llvm::outs() << "ir after preprocessing\n");
LLVM_DEBUG(module.print(llvm::outs()));

// Ensure that air.dma_memcpy_nd ops between L1 and L2 are within at least
// two parent scf.parallel loops.
module.walk([&](air::DmaMemcpyNdOp op) {
Expand Down Expand Up @@ -1401,8 +1228,6 @@ struct ParallelToHerdPass
patterns.add<AffineParToHerdConversion>(context);
patterns.add<ScfParToHerdConversion>(context, filteredOps, replacementOps,
clFirstDim);
patterns.add<ScfForallToHerdConversion>(context, filteredOps,
replacementOps, clFirstDim);

ConversionTarget target(*context);

Expand Down Expand Up @@ -1448,6 +1273,19 @@ struct ParallelToLaunchPass
LLVM_DEBUG(llvm::outs() << "input\n");
LLVM_DEBUG(module.print(llvm::outs()));

// Preprocessing: convert forall to parallel.
RewritePatternSet preprocPatterns(context);
preprocPatterns.add<SCFForAllToParallelOp>(context);
ConversionTarget preprocTarget(*context);
preprocTarget.addLegalDialect<scf::SCFDialect, arith::ArithDialect>();
preprocTarget.addIllegalOp<scf::ForallOp>();
if (failed(applyPartialConversion(module, preprocTarget,
std::move(preprocPatterns)))) {
signalPassFailure();
}
LLVM_DEBUG(llvm::outs() << "ir after preprocessing\n");
LLVM_DEBUG(module.print(llvm::outs()));

llvm::SmallVector<air::LaunchOp> launchOps;
module.walk([&](air::LaunchOp op) { launchOps.push_back(op); });

Expand Down Expand Up @@ -1538,8 +1376,6 @@ struct ParallelToLaunchPass
RewritePatternSet patterns(context);
patterns.add<ScfParToLaunchConversion>(context, filteredOps, replacementOps,
clHasSegment);
patterns.add<ScfForallToLaunchConversion>(context, filteredOps,
replacementOps, clHasSegment);

ConversionTarget target(*context);

Expand Down Expand Up @@ -1611,8 +1447,6 @@ transform::ParToHerdOp::applyToOne(transform::TransformRewriter &rewriter,
filteredOps.insert(target);
patterns.add<ScfParToHerdConversion>(ctx, filteredOps, herdOps,
getFirstDim());
patterns.add<ScfForallToHerdConversion>(ctx, filteredOps, herdOps,
getFirstDim());
(void)applyPatternsGreedily(
target->getParentWithTrait<OpTrait::IsIsolatedFromAbove>(),
std::move(patterns));
Expand All @@ -1639,8 +1473,6 @@ transform::ParToLaunchOp::applyToOne(transform::TransformRewriter &rewriter,
filteredOps.insert(target);
patterns.add<ScfParToLaunchConversion>(ctx, filteredOps, launchOps,
getHasAirSegment());
patterns.add<ScfForallToLaunchConversion>(ctx, filteredOps, launchOps,
getHasAirSegment());
(void)applyPatternsGreedily(
target->getParentWithTrait<OpTrait::IsIsolatedFromAbove>(),
std::move(patterns));
Expand Down
35 changes: 35 additions & 0 deletions mlir/test/Conversion/ConvertToAIR/scf_forall_to_herd.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,41 @@ func.func @scf2() {

// -----

// CHECK: [[$MAP0:#map[0-9]*]] = affine_map<(d0) -> (d0 * 2)>
// CHECK: [[$MAP1:#map[0-9]*]] = affine_map<(d0) -> (d0 + 1)>
// CHECK-LABEL: func.func @scf3() {
// CHECK: air.herd @herd_0 tile (%[[VAL_0:.*]], %[[VAL_1:.*]]) in (%{{.*}}=%c3{{.*}}, %{{.*}}=%c2{{.*}})
// CHECK: affine.apply [[$MAP0]](%[[VAL_1]])
// CHECK: affine.apply [[$MAP1]](%[[VAL_0]])
// Input: a 2-d scf.forall that is not normalized and whose bounds and steps
// are SSA index constants. %i runs from 1 to 4 with step 1 (3 iterations) and
// %j runs from 0 to 4 with step 2 (2 iterations). The body uses both
// induction variables.
func.func @scf3() {
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
%c2 = arith.constant 2 : index
%c4 = arith.constant 4 : index
scf.forall (%i, %j) = (%c1, %c0) to (%c4, %c4)
step (%c1, %c2) {
%2 = arith.muli %i, %j : index
}
return
}

// -----

// CHECK: [[$MAP0:#map[0-9]*]] = affine_map<(d0) -> (d0 * 2)>
// CHECK: [[$MAP1:#map[0-9]*]] = affine_map<(d0) -> (d0 + 1)>
// CHECK-LABEL: func.func @scf4() {
// CHECK: air.herd @herd_0 tile (%[[VAL_0:.*]], %[[VAL_1:.*]]) in (%{{.*}}=%c3{{.*}}, %{{.*}}=%c2{{.*}})
// CHECK: affine.apply [[$MAP0]](%[[VAL_1]])
// CHECK: affine.apply [[$MAP1]](%[[VAL_0]])
// Input: a 2-d scf.forall that is not normalized, written with static
// (attribute) bounds and steps. %i runs from 1 to 4 with step 1
// (3 iterations) and %j runs from 0 to 4 with step 2 (2 iterations). The body
// uses both induction variables.
func.func @scf4() {
scf.forall (%i, %j) = (1, 0) to (4, 4) step (1, 2) {
%2 = arith.muli %i, %j : index
}
return
}

// -----

// This test demonstrates that while forming air.herd we look through func.call ops, fetch
// the corresponding function declaration's 'link_with' attribute and attach it to the newly
// formed air.herd op.
Expand Down
Loading