Skip to content

Commit

Permalink
Rewrite the operation id assignment, so that input ops, both sync and…
Browse files Browse the repository at this point in the history
… async, get unique ids for the dependency graph
  • Loading branch information
erwei-xilinx committed Mar 7, 2025
1 parent cfe7094 commit 8aa8318
Showing 1 changed file with 50 additions and 44 deletions.
94 changes: 50 additions & 44 deletions mlir/lib/Transform/AIRDependency.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -114,10 +114,12 @@ class AIRDependency : public air::impl::AIRDependencyBase<AIRDependency> {
HierarchyOpID = 0;
WaitAllOpID = 0;
ChannelOpID = 0;
DmaOpID = 0;

for (auto f : module.getOps<func::FuncOp>()) {
f.walk([&](Operation *op) {
if (air::isAsyncOp(op)) {
assignOpId(op);
updateAsyncExecuteGraphWithNewNode(op, asyncExecuteGraph);
return; // Skip if is already async.
}
Expand All @@ -127,24 +129,22 @@ class AIRDependency : public air::impl::AIRDependencyBase<AIRDependency> {
if (isa<air::DmaMemcpyNdOp>(op))
createAsyncDMA(module_builder, op);
else if (isa<air::ChannelInterface>(op))
createAsyncChannel(module_builder, op, ChannelOpID);
createAsyncChannel(module_builder, op);
else if (isa<linalg::LinalgOp, func::CallOp, memref::DeallocOp,
memref::CopyOp>(op))
createAsyncExecute(module_builder, op, ExecuteOpID);
createAsyncExecute(module_builder, op);
else if (isa<memref::CastOp, affine::AffineApplyOp, arith::AddIOp,
arith::MulIOp>(op))
createAsyncExecute(module_builder, op, ExecuteOpID,
op->getResult(0).getType());
createAsyncExecute(module_builder, op, op->getResult(0).getType());
else if (auto hierarchy_op = dyn_cast<air::HierarchyInterface>(op))
createAsyncHierarchyImpls(module_builder, hierarchy_op,
HierarchyOpID);
createAsyncHierarchyImpls(module_builder, hierarchy_op);
// Create async execute region for memref.alloc
else if (auto memalloc_op = dyn_cast<memref::AllocOp>(op)) {
// Alloc can be used to specify shapes for operations such
// as reshape ops. If this alloc is used to specify shape of
// a reshape op, ignore this operation.
if (!alloc_for_reshape(memalloc_op->getOpResult(0)))
createAsyncExecute(module_builder, op, ExecuteOpID,
createAsyncExecute(module_builder, op,
memalloc_op.getMemref().getType());
}

Expand Down Expand Up @@ -187,10 +187,10 @@ class AIRDependency : public air::impl::AIRDependencyBase<AIRDependency> {
}
if (isCandidateExecute) {
if (op->getNumResults())
createAsyncExecute(module_builder, op, ExecuteOpID,
createAsyncExecute(module_builder, op,
op->getResults().front().getType());
else
createAsyncExecute(module_builder, op, ExecuteOpID);
createAsyncExecute(module_builder, op);
}
}
});
Expand Down Expand Up @@ -379,6 +379,7 @@ class AIRDependency : public air::impl::AIRDependencyBase<AIRDependency> {
uint64_t HierarchyOpID;
uint64_t WaitAllOpID;
uint64_t ChannelOpID;
uint64_t DmaOpID;

//===----------------------------------------------------------------------===//
// Handling lingering reshape-related ops
Expand Down Expand Up @@ -409,17 +410,14 @@ class AIRDependency : public air::impl::AIRDependencyBase<AIRDependency> {

// Create air execute op with async interface (no ssa result returned); update
// graph
air::ExecuteOp createAsyncExecute(OpBuilder &builder, Operation *op,
uint64_t &ExecuteOpID) {
air::ExecuteOp createAsyncExecute(OpBuilder &builder, Operation *op) {
builder.setInsertionPoint(op);
auto loc = op->getLoc();
SmallVector<Value, 1> deps;
air::ExecuteOp async_region;
async_region = builder.create<air::ExecuteOp>(
loc, air::AsyncTokenType::get(op->getContext()), deps);
async_region->setAttr(
"id", mlir::IntegerAttr::get(
mlir::IntegerType::get(op->getContext(), 32), ++ExecuteOpID));
assignOpId(async_region);

// Insert op to the new async execute region's body.
Block *async_region_bb = builder.createBlock(&async_region.getRegion());
Expand Down Expand Up @@ -459,7 +457,6 @@ class AIRDependency : public air::impl::AIRDependencyBase<AIRDependency> {
// Create air execute op with async interface (with one ssa result returned);
// update graph
air::ExecuteOp createAsyncExecute(OpBuilder &builder, Operation *op,
uint64_t &ExecuteOpID,
mlir::Type valueType) {
builder.setInsertionPoint(op);
auto loc = op->getLoc();
Expand All @@ -468,9 +465,7 @@ class AIRDependency : public air::impl::AIRDependencyBase<AIRDependency> {
async_region = builder.create<air::ExecuteOp>(
loc, air::AsyncTokenType::get(op->getContext()),
op->getResults().getType(), deps);
async_region->setAttr(
"id", mlir::IntegerAttr::get(
mlir::IntegerType::get(op->getContext(), 32), ++ExecuteOpID));
assignOpId(async_region);

// Insert op to the new async execute region's body.
Block *async_region_bb = builder.createBlock(&async_region.getRegion());
Expand Down Expand Up @@ -500,15 +495,13 @@ class AIRDependency : public air::impl::AIRDependencyBase<AIRDependency> {
auto loc = op->getLoc();
SmallVector<Value, 1> deps;
auto dma_op = mlir::dyn_cast<air::DmaMemcpyNdOp>(op);
unsigned id = dma_op.getId();
// unsigned id = dma_op.getId();
air::DmaMemcpyNdOp new_dmaNd_op = builder.create<air::DmaMemcpyNdOp>(
loc, air::AsyncTokenType::get(dma_op->getContext()), deps,
dma_op.getDstMemref(), dma_op.getDstOffsets(), dma_op.getDstSizes(),
dma_op.getDstStrides(), dma_op.getSrcMemref(), dma_op.getSrcOffsets(),
dma_op.getSrcSizes(), dma_op.getSrcStrides());
new_dmaNd_op->setAttr(
"id", mlir::IntegerAttr::get(
mlir::IntegerType::get(op->getContext(), 32), id));
assignOpId(new_dmaNd_op);

// Update op-to-graph map
updateAsyncExecuteGraphWithNewNode(new_dmaNd_op, asyncExecuteGraph);
Expand All @@ -520,8 +513,7 @@ class AIRDependency : public air::impl::AIRDependencyBase<AIRDependency> {
}

// Re-instantiate the channel op with async interface; update graph
void createAsyncChannel(OpBuilder &builder, Operation *op,
uint64_t &ChannelOpID) {
void createAsyncChannel(OpBuilder &builder, Operation *op) {
builder.setInsertionPoint(op);
auto loc = op->getLoc();
SmallVector<Value, 1> deps;
Expand All @@ -532,10 +524,7 @@ class AIRDependency : public air::impl::AIRDependencyBase<AIRDependency> {
channel_put_op.getChanName(), channel_put_op.getIndices(),
channel_put_op.getSrc(), channel_put_op.getSrcOffsets(),
channel_put_op.getSrcSizes(), channel_put_op.getSrcStrides());
new_channel_put_op->setAttr(
"id",
mlir::IntegerAttr::get(mlir::IntegerType::get(op->getContext(), 32),
++ChannelOpID));
assignOpId(new_channel_put_op);
event_name = "Put";
// Update op-to-graph map
updateAsyncExecuteGraphWithNewNode(new_channel_put_op, asyncExecuteGraph);
Expand All @@ -545,10 +534,7 @@ class AIRDependency : public air::impl::AIRDependencyBase<AIRDependency> {
channel_get_op.getChanName(), channel_get_op.getIndices(),
channel_get_op.getDst(), channel_get_op.getDstOffsets(),
channel_get_op.getDstSizes(), channel_get_op.getDstStrides());
new_channel_get_op->setAttr(
"id",
mlir::IntegerAttr::get(mlir::IntegerType::get(op->getContext(), 32),
++ChannelOpID));
assignOpId(new_channel_get_op);
event_name = "Get";
// Update op-to-graph map
updateAsyncExecuteGraphWithNewNode(new_channel_get_op, asyncExecuteGraph);
Expand All @@ -562,9 +548,8 @@ class AIRDependency : public air::impl::AIRDependencyBase<AIRDependency> {
}

// Re-instantiate the hierarchy op with async interface; update graph
air::HierarchyInterface createAsyncHierarchyImpls(OpBuilder &builder,
air::HierarchyInterface op,
uint64_t &HierarchyOpID) {
air::HierarchyInterface
createAsyncHierarchyImpls(OpBuilder &builder, air::HierarchyInterface op) {
builder.setInsertionPoint(op);
SmallVector<Value, 1> deps;
SmallVector<Value, 4> args;
Expand All @@ -580,19 +565,19 @@ class AIRDependency : public air::impl::AIRDependencyBase<AIRDependency> {
Operation *new_op = nullptr;
if (auto launch = dyn_cast<air::LaunchOp>(op.getOperation())) {
auto new_launch = createAsyncHierarchy<air::LaunchOp>(
builder, launch, HierarchyOpID, deps, args, constants);
builder, launch, deps, args, constants);
new_op = new_launch.getOperation();
// Update op-to-graph map
updateAsyncExecuteGraphWithNewNode(new_launch, asyncExecuteGraph);
} else if (auto segment = dyn_cast<air::SegmentOp>(op.getOperation())) {
auto new_segment = createAsyncHierarchy<air::SegmentOp>(
builder, segment, HierarchyOpID, deps, args, constants);
builder, segment, deps, args, constants);
new_op = new_segment.getOperation();
// Update op-to-graph map
updateAsyncExecuteGraphWithNewNode(new_segment, asyncExecuteGraph);
} else if (auto herd = dyn_cast<air::HerdOp>(op.getOperation())) {
auto new_herd = createAsyncHierarchy<air::HerdOp>(
builder, herd, HierarchyOpID, deps, args, constants);
auto new_herd = createAsyncHierarchy<air::HerdOp>(builder, herd, deps,
args, constants);
new_op = new_herd.getOperation();
// Update op-to-graph map
updateAsyncExecuteGraphWithNewNode(new_herd, asyncExecuteGraph);
Expand All @@ -609,15 +594,13 @@ class AIRDependency : public air::impl::AIRDependencyBase<AIRDependency> {
}

template <typename T>
T createAsyncHierarchy(OpBuilder &builder, T op, uint64_t &OpID,
SmallVector<Value, 1> deps, SmallVector<Value, 4> args,
T createAsyncHierarchy(OpBuilder &builder, T op, SmallVector<Value, 1> deps,
SmallVector<Value, 4> args,
SmallVector<Value, 4> constants) {
auto loc = op->getLoc();
T new_op = builder.create<T>(loc, deps, op.getSizeOperands(), args, true,
op->getAttrs());
new_op->setAttr("id",
mlir::IntegerAttr::get(
mlir::IntegerType::get(op->getContext(), 32), ++OpID));
assignOpId(new_op);

auto &bb = new_op.getBody().front();
for (unsigned i = 0; i < op.getIds().size(); i++) {
Expand Down Expand Up @@ -683,6 +666,29 @@ class AIRDependency : public air::impl::AIRDependencyBase<AIRDependency> {
setNodeAttrsBasedOnOp(op);
}

// Stamp `op` with a 32-bit integer "id" attribute drawn from the counter that
// matches the op's kind. Kinds are tested in a fixed order and the first match
// wins: air::ExecuteOp, air::ChannelInterface, air::HierarchyInterface,
// air::WaitAllOp, air::DmaMemcpyNdOp. The selected counter is pre-incremented,
// so the attribute holds the counter's new value. An op that matches none of
// these kinds is left untouched and no counter changes.
void assignOpId(Operation *op) {
  // Select the per-kind counter; stays null when the op kind is not tracked.
  uint64_t *counter = nullptr;
  if (isa<air::ExecuteOp>(op))
    counter = &ExecuteOpID;
  else if (isa<air::ChannelInterface>(op))
    counter = &ChannelOpID;
  else if (isa<air::HierarchyInterface>(op))
    counter = &HierarchyOpID;
  else if (isa<air::WaitAllOp>(op))
    counter = &WaitAllOpID;
  else if (isa<air::DmaMemcpyNdOp>(op))
    counter = &DmaOpID;

  if (!counter)
    return;

  // Bump the chosen counter and record its new value on the op.
  auto idType = mlir::IntegerType::get(op->getContext(), 32);
  op->setAttr("id", mlir::IntegerAttr::get(idType, ++(*counter)));
}

//===----------------------------------------------------------------------===//
// Data dependency tracing
//===----------------------------------------------------------------------===//
Expand Down

0 comments on commit 8aa8318

Please sign in to comment.