From 8aa8318dee83eec7bbb01a3580cc71f065bbeb6c Mon Sep 17 00:00:00 2001 From: erwei-xilinx Date: Fri, 7 Mar 2025 14:43:55 -0800 Subject: [PATCH] Rewrite the operation id assignment, so that input ops, both sync and async, get unique ids for the dependency graph --- mlir/lib/Transform/AIRDependency.cpp | 94 +++++++++++++++------------- 1 file changed, 50 insertions(+), 44 deletions(-) diff --git a/mlir/lib/Transform/AIRDependency.cpp b/mlir/lib/Transform/AIRDependency.cpp index b7f67a03f..73aa3497c 100644 --- a/mlir/lib/Transform/AIRDependency.cpp +++ b/mlir/lib/Transform/AIRDependency.cpp @@ -114,10 +114,12 @@ class AIRDependency : public air::impl::AIRDependencyBase { HierarchyOpID = 0; WaitAllOpID = 0; ChannelOpID = 0; + DmaOpID = 0; for (auto f : module.getOps()) { f.walk([&](Operation *op) { if (air::isAsyncOp(op)) { + assignOpId(op); updateAsyncExecuteGraphWithNewNode(op, asyncExecuteGraph); return; // Skip if is already async. } @@ -127,24 +129,22 @@ class AIRDependency : public air::impl::AIRDependencyBase { if (isa(op)) createAsyncDMA(module_builder, op); else if (isa(op)) - createAsyncChannel(module_builder, op, ChannelOpID); + createAsyncChannel(module_builder, op); else if (isa(op)) - createAsyncExecute(module_builder, op, ExecuteOpID); + createAsyncExecute(module_builder, op); else if (isa(op)) - createAsyncExecute(module_builder, op, ExecuteOpID, - op->getResult(0).getType()); + createAsyncExecute(module_builder, op, op->getResult(0).getType()); else if (auto hierarchy_op = dyn_cast(op)) - createAsyncHierarchyImpls(module_builder, hierarchy_op, - HierarchyOpID); + createAsyncHierarchyImpls(module_builder, hierarchy_op); // Create async execute region for memref.alloc else if (auto memalloc_op = dyn_cast(op)) { // Alloc can be used to specify shapes for operations such // as reshape ops. If this alloc is used to specify shape of // a reshap op, ignore this operation. if (!alloc_for_reshape(memalloc_op->getOpResult(0))) - createAsyncExecute(module_builder, op, ExecuteOpID, + createAsyncExecute(module_builder, op, memalloc_op.getMemref().getType()); } @@ -187,10 +187,10 @@ class AIRDependency : public air::impl::AIRDependencyBase { } if (isCandidateExecute) { if (op->getNumResults()) - createAsyncExecute(module_builder, op, ExecuteOpID, + createAsyncExecute(module_builder, op, op->getResults().front().getType()); else - createAsyncExecute(module_builder, op, ExecuteOpID); + createAsyncExecute(module_builder, op); } } }); @@ -379,6 +379,7 @@ class AIRDependency : public air::impl::AIRDependencyBase { uint64_t HierarchyOpID; uint64_t WaitAllOpID; uint64_t ChannelOpID; + uint64_t DmaOpID; //===----------------------------------------------------------------------===// // Handling lingering reshape-related ops @@ -409,17 +410,14 @@ class AIRDependency : public air::impl::AIRDependencyBase { // Create air execute op with async interface (no ssa result returned); update // graph - air::ExecuteOp createAsyncExecute(OpBuilder &builder, Operation *op, - uint64_t &ExecuteOpID) { + air::ExecuteOp createAsyncExecute(OpBuilder &builder, Operation *op) { builder.setInsertionPoint(op); auto loc = op->getLoc(); SmallVector deps; air::ExecuteOp async_region; async_region = builder.create( loc, air::AsyncTokenType::get(op->getContext()), deps); - async_region->setAttr( - "id", mlir::IntegerAttr::get( - mlir::IntegerType::get(op->getContext(), 32), ++ExecuteOpID)); + assignOpId(async_region); // Insert op to the new async execute region's body. Block *async_region_bb = builder.createBlock(&async_region.getRegion()); @@ -459,7 +457,6 @@ class AIRDependency : public air::impl::AIRDependencyBase { // Create air execute op with async interface (with one ssa result returned); // update graph air::ExecuteOp createAsyncExecute(OpBuilder &builder, Operation *op, - uint64_t &ExecuteOpID, mlir::Type valueType) { builder.setInsertionPoint(op); auto loc = op->getLoc(); @@ -468,9 +465,7 @@ class AIRDependency : public air::impl::AIRDependencyBase { async_region = builder.create( loc, air::AsyncTokenType::get(op->getContext()), op->getResults().getType(), deps); - async_region->setAttr( - "id", mlir::IntegerAttr::get( - mlir::IntegerType::get(op->getContext(), 32), ++ExecuteOpID)); + assignOpId(async_region); // Insert op to the new async execute region's body. Block *async_region_bb = builder.createBlock(&async_region.getRegion()); @@ -500,15 +495,13 @@ class AIRDependency : public air::impl::AIRDependencyBase { auto loc = op->getLoc(); SmallVector deps; auto dma_op = mlir::dyn_cast(op); - unsigned id = dma_op.getId(); + // unsigned id = dma_op.getId(); air::DmaMemcpyNdOp new_dmaNd_op = builder.create( loc, air::AsyncTokenType::get(dma_op->getContext()), deps, dma_op.getDstMemref(), dma_op.getDstOffsets(), dma_op.getDstSizes(), dma_op.getDstStrides(), dma_op.getSrcMemref(), dma_op.getSrcOffsets(), dma_op.getSrcSizes(), dma_op.getSrcStrides()); - new_dmaNd_op->setAttr( - "id", mlir::IntegerAttr::get( - mlir::IntegerType::get(op->getContext(), 32), id)); + assignOpId(new_dmaNd_op); // Update op-to-graph map updateAsyncExecuteGraphWithNewNode(new_dmaNd_op, asyncExecuteGraph); @@ -520,8 +513,7 @@ class AIRDependency : public air::impl::AIRDependencyBase { } // Re-instantiate the channel op with async interface; update graph - void createAsyncChannel(OpBuilder &builder, Operation *op, - uint64_t &ChannelOpID) { + void createAsyncChannel(OpBuilder &builder, Operation *op) { builder.setInsertionPoint(op); auto loc = op->getLoc(); SmallVector deps; @@ -532,10 +524,7 @@ class AIRDependency : public air::impl::AIRDependencyBase { channel_put_op.getChanName(), channel_put_op.getIndices(), channel_put_op.getSrc(), channel_put_op.getSrcOffsets(), channel_put_op.getSrcSizes(), channel_put_op.getSrcStrides()); - new_channel_put_op->setAttr( - "id", - mlir::IntegerAttr::get(mlir::IntegerType::get(op->getContext(), 32), - ++ChannelOpID)); + assignOpId(new_channel_put_op); event_name = "Put"; // Update op-to-graph map updateAsyncExecuteGraphWithNewNode(new_channel_put_op, asyncExecuteGraph); @@ -545,10 +534,7 @@ class AIRDependency : public air::impl::AIRDependencyBase { channel_get_op.getChanName(), channel_get_op.getIndices(), channel_get_op.getDst(), channel_get_op.getDstOffsets(), channel_get_op.getDstSizes(), channel_get_op.getDstStrides()); - new_channel_get_op->setAttr( - "id", - mlir::IntegerAttr::get(mlir::IntegerType::get(op->getContext(), 32), - ++ChannelOpID)); + assignOpId(new_channel_get_op); event_name = "Get"; // Update op-to-graph map updateAsyncExecuteGraphWithNewNode(new_channel_get_op, asyncExecuteGraph); @@ -562,9 +548,8 @@ class AIRDependency : public air::impl::AIRDependencyBase { } // Re-instantiate the hierarchy op with async interface; update graph - air::HierarchyInterface createAsyncHierarchyImpls(OpBuilder &builder, - air::HierarchyInterface op, - uint64_t &HierarchyOpID) { + air::HierarchyInterface + createAsyncHierarchyImpls(OpBuilder &builder, air::HierarchyInterface op) { builder.setInsertionPoint(op); SmallVector deps; SmallVector args; @@ -580,19 +565,19 @@ class AIRDependency : public air::impl::AIRDependencyBase { Operation *new_op = nullptr; if (auto launch = dyn_cast(op.getOperation())) { auto new_launch = createAsyncHierarchy( - builder, launch, HierarchyOpID, deps, args, constants); + builder, launch, deps, args, constants); new_op = new_launch.getOperation(); // Update op-to-graph map updateAsyncExecuteGraphWithNewNode(new_launch, asyncExecuteGraph); } else if (auto segment = dyn_cast(op.getOperation())) { auto new_segment = createAsyncHierarchy( - builder, segment, HierarchyOpID, deps, args, constants); + builder, segment, deps, args, constants); new_op = new_segment.getOperation(); // Update op-to-graph map updateAsyncExecuteGraphWithNewNode(new_segment, asyncExecuteGraph); } else if (auto herd = dyn_cast(op.getOperation())) { - auto new_herd = createAsyncHierarchy( - builder, herd, HierarchyOpID, deps, args, constants); + auto new_herd = createAsyncHierarchy(builder, herd, deps, + args, constants); new_op = new_herd.getOperation(); // Update op-to-graph map updateAsyncExecuteGraphWithNewNode(new_herd, asyncExecuteGraph); @@ -609,15 +594,13 @@ class AIRDependency : public air::impl::AIRDependencyBase { } template - T createAsyncHierarchy(OpBuilder &builder, T op, uint64_t &OpID, - SmallVector deps, SmallVector args, + T createAsyncHierarchy(OpBuilder &builder, T op, SmallVector deps, + SmallVector args, SmallVector constants) { auto loc = op->getLoc(); T new_op = builder.create(loc, deps, op.getSizeOperands(), args, true, op->getAttrs()); - new_op->setAttr("id", - mlir::IntegerAttr::get( - mlir::IntegerType::get(op->getContext(), 32), ++OpID)); + assignOpId(new_op); auto &bb = new_op.getBody().front(); for (unsigned i = 0; i < op.getIds().size(); i++) { @@ -683,6 +666,29 @@ class AIRDependency : public air::impl::AIRDependencyBase { setNodeAttrsBasedOnOp(op); } + void assignOpId(Operation *op) { + if (isa(op)) + op->setAttr("id", mlir::IntegerAttr::get( + mlir::IntegerType::get(op->getContext(), 32), + ++ExecuteOpID)); + else if (isa(op)) + op->setAttr("id", mlir::IntegerAttr::get( + mlir::IntegerType::get(op->getContext(), 32), + ++ChannelOpID)); + else if (isa(op)) + op->setAttr("id", mlir::IntegerAttr::get( + mlir::IntegerType::get(op->getContext(), 32), + ++HierarchyOpID)); + else if (isa(op)) + op->setAttr("id", mlir::IntegerAttr::get( + mlir::IntegerType::get(op->getContext(), 32), + ++WaitAllOpID)); + else if (isa(op)) + op->setAttr("id", + mlir::IntegerAttr::get( + mlir::IntegerType::get(op->getContext(), 32), ++DmaOpID)); + } + //===----------------------------------------------------------------------===// // Data dependency tracing //===----------------------------------------------------------------------===//