Skip to content

Commit

Permalink
Refactor dataflow bufferization
Browse files Browse the repository at this point in the history
  • Loading branch information
hanchenye committed Jan 9, 2024
1 parent c9d7312 commit 4bf297f
Show file tree
Hide file tree
Showing 8 changed files with 110 additions and 119 deletions.
1 change: 0 additions & 1 deletion include/scalehls/Dialect/HLS/Transforms/Passes.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@ namespace mlir {
namespace scalehls {
namespace hls {

std::unique_ptr<Pass> createEliminateBufferYieldPass();
std::unique_ptr<Pass> createLowerDataflowPass();

#define GEN_PASS_CLASSES
Expand Down
6 changes: 0 additions & 6 deletions include/scalehls/Dialect/HLS/Transforms/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,6 @@

include "mlir/Pass/PassBase.td"

def EliminateBufferYield :
Pass<"scalehls-fdf-eliminate-buffer-yield", "func::FuncOp"> {
let summary = "Eliminate unecessary buffer yield operations";
let constructor = "mlir::scalehls::hls::createEliminateBufferYieldPass()";
}

def LowerDataflow : Pass<"scalehls-lower-dataflow", "func::FuncOp"> {
let summary = "Convert functional to structural dataflow";
let constructor = "mlir::scalehls::hls::createLowerDataflowPass()";
Expand Down
28 changes: 28 additions & 0 deletions lib/Dialect/HLS/IR/HLSOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -83,12 +83,39 @@ struct InlineDispatchOrTask : public OpRewritePattern<OpType> {
};
} // namespace

namespace {
/// Demote buffers yielded from a dispatch/task op: hoist the defining
/// BufferOp above the op and redirect users of the corresponding op result
/// to the buffer itself. The now-dead yield operand and op result are
/// expected to be cleaned up by other canonicalization patterns.
template <typename OpType>
struct DemoteYieldedBuffer : public OpRewritePattern<OpType> {
  using OpRewritePattern<OpType>::OpRewritePattern;

  LogicalResult matchAndRewrite(OpType op,
                                PatternRewriter &rewriter) const override {
    auto yield = op.getYieldOp();
    bool hasChanged = false;

    // Eliminate each yielded buffer. It's always safe to move the buffer to
    // a higher level of the hierarchy.
    for (auto [yieldedValue, result] :
         llvm::zip(yield.getOperands(), op.getResults()))
      if (auto buffer = yieldedValue.template getDefiningOp<BufferOp>()) {
        // Hoist the buffer out of the op through the rewriter so the greedy
        // driver is notified of the movement.
        if (op->isAncestor(buffer)) {
          rewriter.moveOpBefore(buffer, op);
          hasChanged = true;
        }

        // Redirect users of the op result to the demoted buffer. Only report
        // a change when there are uses to replace; otherwise the pattern
        // would keep matching without making progress and the canonicalizer
        // could fail to converge.
        if (!result.use_empty()) {
          rewriter.replaceAllUsesWith(result, buffer);
          hasChanged = true;
        }
      }
    return success(hasChanged);
  }
};
} // namespace

void DispatchOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                             MLIRContext *context) {
  results.add<SimplifyDispatchOrTaskOutputs<DispatchOp>>(context);
  // A dispatch can be inlined when it contains no task at all, or when it
  // holds exactly one operation.
  auto isInlinable = [](DispatchOp dispatch) {
    return dispatch.getOps<TaskOp>().empty() ||
           llvm::hasSingleElement(dispatch.getOps());
  };
  results.add<InlineDispatchOrTask<DispatchOp>>(context, isInlinable);
  results.add<DemoteYieldedBuffer<DispatchOp>>(context);
}

LogicalResult DispatchOp::verify() {
Expand All @@ -111,6 +138,7 @@ void TaskOp::getCanonicalizationPatterns(RewritePatternSet &results,
results.add<SimplifyDispatchOrTaskOutputs<TaskOp>>(context);
results.add<InlineDispatchOrTask<TaskOp>>(
context, [](TaskOp op) { return llvm::hasSingleElement(op.getOps()); });
results.add<DemoteYieldedBuffer<TaskOp>>(context);
}

LogicalResult TaskOp::verify() {
Expand Down
110 changes: 71 additions & 39 deletions lib/Dialect/HLS/Transforms/BufferizableOpInterfaceImpl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,9 @@

#include "scalehls/Dialect/HLS/Transforms/BufferizableOpInterfaceImpl.h"
#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Interfaces/ViewLikeInterface.h"
#include "scalehls/Dialect/HLS/IR/HLS.h"

using namespace mlir;
Expand All @@ -32,20 +35,19 @@ struct DispatchOrTaskOpInterface
/// Return the yield operand that aliases with the given dispatch/task result.
AliasingOpOperandList
getAliasingOpOperands(Operation *op, Value value,
                      const AnalysisState &state) const {
  // NOTE(review): the scraped diff carried both the pre- and post-refactor
  // computations of the operand index; only the refactored form is kept.
  // Dispatch/task results are tied positionally to the operands of the
  // terminating yield op, so the aliasing operand is the yield operand at
  // the result's own index.
  OpOperand *operand = &cast<OpType>(op).getYieldOp()->getOpOperand(
      cast<OpResult>(value).getResultNumber());
  return {{operand, BufferRelation::Equivalent}};
}

LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
const BufferizationOptions &options) const {
OpBuilder::InsertionGuard g(rewriter);
auto concreteOp = cast<OpType>(op);

// Compute bufferized result types.
SmallVector<Type> newTypes;
for (Value result : op->getResults()) {
for (Value result : concreteOp.getResults()) {
if (!result.getType().isa<TensorType>()) {
newTypes.push_back(result.getType());
continue;
Expand All @@ -57,13 +59,13 @@ struct DispatchOrTaskOpInterface
}

// Create new dispatch/task op.
rewriter.setInsertionPoint(op);
auto newOp = rewriter.create<OpType>(op->getLoc(), newTypes);
rewriter.inlineRegionBefore(cast<OpType>(op).getBody(), newOp.getBody(),
rewriter.setInsertionPoint(concreteOp);
auto newOp = rewriter.create<OpType>(concreteOp.getLoc(), newTypes);
rewriter.inlineRegionBefore(concreteOp.getBody(), newOp.getBody(),
newOp.getBody().end());

// Replace dispatch/task op results.
replaceOpWithBufferizedValues(rewriter, op, newOp->getResults());
replaceOpWithBufferizedValues(rewriter, concreteOp, newOp->getResults());
return success();
}

Expand All @@ -86,10 +88,11 @@ struct DispatchOrTaskOpInterface
}
};

/// Bufferization of fdf.yield operation. Bufferized as part of their enclosing
/// ops, so this is for analysis only.
/// Bufferization of fdf.yield operation.
struct YieldOpInterface
: public BufferizableOpInterface::ExternalModel<YieldOpInterface, YieldOp> {
/// Yielded values may be materialized as fresh allocations during
/// bufferization (bufferize() below creates an explicit alloc + copy for
/// view-like buffers), hence report true.
bool bufferizesToAllocation(Operation *op, Value value) const { return true; }

bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
const AnalysisState &state) const {
return true;
Expand Down Expand Up @@ -117,30 +120,71 @@ struct YieldOpInterface

LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
const BufferizationOptions &options) const {
SmallVector<Value> newResults;
for (const auto value : cast<YieldOp>(op).getResults()) {
if (value.getType().isa<TensorType>()) {
FailureOr<Value> maybeBuffer = getBuffer(rewriter, value, options);
if (failed(maybeBuffer))
OpBuilder::InsertionGuard g(rewriter);
auto yield = cast<YieldOp>(op);
auto parent = yield->getParentOp();

// Traverse and bufferize each operand of the yield operation.
for (auto operand : yield.getOperands()) {
if (!operand.getType().isa<TensorType>())
continue;

auto maybeBuffer = getBuffer(rewriter, operand, options);
auto maybeType = bufferization::getBufferType(operand, options);
if (failed(maybeBuffer) || failed(maybeType))
continue;

// For now, we always generate an explicit copy to handle view-like
// operations. This is not efficient but it's safe.
if (auto view = maybeBuffer->getDefiningOp<ViewLikeOpInterface>()) {
rewriter.setInsertionPoint(parent);
auto localBuffer = options.createAlloc(
rewriter, yield.getLoc(), maybeType->cast<MemRefType>(), {});
if (failed(localBuffer))
return failure();

rewriter.setInsertionPoint(yield);
if (failed(options.createMemCpy(rewriter, yield.getLoc(), *maybeBuffer,
*localBuffer)))
return failure();
newResults.push_back(*maybeBuffer);

rewriter.replaceUsesWithIf(operand, *localBuffer, [&](OpOperand &use) {
return use.getOwner() == yield;
});
} else {
newResults.push_back(value);
rewriter.setInsertionPoint(yield);
auto replacement = rewriter.create<bufferization::ToMemrefOp>(
yield.getLoc(), *maybeType, operand);
rewriter.replaceUsesWithIf(operand, replacement, [&](OpOperand &use) {
return use.getOwner() == yield;
});
}
}
replaceOpWithNewBufferizedOp<YieldOp>(rewriter, op, newResults);
return success();
}
};

/// Bufferization of fdf.alloc_tensor operation.
struct AllocTensorOpInterface
: public BufferizableOpInterface::ExternalModel<AllocTensorOpInterface,
AllocTensorOp> {
hls::AllocTensorOp> {
/// The alloc_tensor result always materializes as a new allocation.
bool bufferizesToAllocation(Operation *op, Value value) const { return true; }

/// The allocation is left uninitialized by this op, so its result does not
/// bufferize to a memory write.
bool resultBufferizesToMemoryWrite(Operation *op, OpResult opResult,
const AnalysisState &state) const {
return false;
}

/// A fresh allocation cannot alias any operand.
AliasingValueList getAliasingValues(Operation *op, OpOperand &opOperand,
const AnalysisState &state) const {
// This is a new allocation. It does not alias with any other buffer.
return {};
}

LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
const BufferizationOptions &options) const {
OpBuilder::InsertionGuard g(rewriter);
auto allocTensor = cast<AllocTensorOp>(op);
auto allocTensor = cast<hls::AllocTensorOp>(op);

// Nothing to do for dead AllocTensorOps.
if (allocTensor->getUses().empty()) {
Expand All @@ -149,12 +193,13 @@ struct AllocTensorOpInterface
}

// Create memory allocation.
auto allocType =
auto maybeType =
bufferization::getBufferType(allocTensor.getResult(), options);
if (failed(allocType))
if (failed(maybeType))
return failure();

FailureOr<Value> buffer = options.createAlloc(
rewriter, allocTensor.getLoc(), allocType->cast<MemRefType>(), {});
rewriter, allocTensor.getLoc(), maybeType->cast<MemRefType>(), {});
if (failed(buffer))
return failure();

Expand All @@ -172,23 +217,10 @@ struct AllocTensorOpInterface
return success();
}

bool resultBufferizesToMemoryWrite(Operation *op, OpResult opResult,
const AnalysisState &state) const {
return false;
}

bool bufferizesToAllocation(Operation *op, Value value) const { return true; }

AliasingValueList getAliasingValues(Operation *op, OpOperand &opOperand,
const AnalysisState &state) const {
// This is a new allocation. It does not alias with any other buffer.
return {};
}

FailureOr<BaseMemRefType>
getBufferType(Operation *op, Value value, const BufferizationOptions &options,
SmallVector<Value> &invocationStack) const {
auto allocTensor = cast<AllocTensorOp>(op);
auto allocTensor = cast<hls::AllocTensorOp>(op);
assert(value == allocTensor.getResult() && "invalid value");

// Compute memory space of this allocation.
Expand All @@ -209,6 +241,6 @@ void mlir::scalehls::hls::registerBufferizableOpInterfaceExternalModels(
DispatchOp::attachInterface<DispatchOrTaskOpInterface<DispatchOp>>(*ctx);
TaskOp::attachInterface<DispatchOrTaskOpInterface<TaskOp>>(*ctx);
YieldOp::attachInterface<YieldOpInterface>(*ctx);
AllocTensorOp::attachInterface<AllocTensorOpInterface>(*ctx);
hls::AllocTensorOp::attachInterface<AllocTensorOpInterface>(*ctx);
});
}
1 change: 0 additions & 1 deletion lib/Dialect/HLS/Transforms/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
add_mlir_dialect_library(MLIRScaleHLSHLSTransforms
BufferizableOpInterfaceImpl.cpp
LowerDataflow.cpp
EliminateBufferYield.cpp

DEPENDS
MLIRScaleHLSHLSTransformsIncGen
Expand Down
70 changes: 0 additions & 70 deletions lib/Dialect/HLS/Transforms/EliminateBufferYield.cpp

This file was deleted.

11 changes: 11 additions & 0 deletions lib/Transforms/ConvertLinalgToDataflow.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,17 @@ struct DispatchFuncOp : public OpRewritePattern<func::FuncOp> {
if (!dispatch)
return failure();

// Ensure each AllocTensorOp is only used once.
for (auto allocTensor :
llvm::make_early_inc_range(dispatch.getOps<hls::AllocTensorOp>())) {
for (auto &use : llvm::make_early_inc_range(allocTensor->getUses())) {
rewriter.setInsertionPoint(use.getOwner());
auto newAllocTensor =
cast<hls::AllocTensorOp>(rewriter.clone(*allocTensor));
use.set(newAllocTensor);
}
}

for (auto &op : llvm::make_early_inc_range(dispatch.getOps())) {
if (auto linalgOp = dyn_cast<linalg::LinalgOp>(op)) {
if (linalgOp.hasDynamicShape())
Expand Down
2 changes: 0 additions & 2 deletions lib/Transforms/Pipelines.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -44,8 +44,6 @@ void scalehls::addComprehensiveBufferizePasses(OpPassManager &pm) {
// can be deleted by canonicalizer. We have to run it again because the
// memrefs are unified in CSE pass, so we can truly remove redundant memcpy.
pm.addPass(mlir::createCanonicalizerPass());
// pm.addNestedPass<func::FuncOp>(hls::createEliminateBufferYieldPass());
// pm.addPass(mlir::createCanonicalizerPass());
}

void scalehls::addLowerDataflowPasses(OpPassManager &pm) {
Expand Down

0 comments on commit 4bf297f

Please sign in to comment.