From b332bc09a16d0dfa7090fd510d9885d7d82e56ae Mon Sep 17 00:00:00 2001 From: erwei-xilinx Date: Mon, 11 Mar 2024 13:40:07 -0700 Subject: [PATCH] Ensure that bcast specialization only propagates through single-operand affine.apply ops (#485) --- .../Transform/AIRDependencyScheduleOpt.cpp | 2 +- mlir/lib/Transform/AIRMiscPasses.cpp | 21 +++++++------------ 2 files changed, 8 insertions(+), 15 deletions(-) diff --git a/mlir/lib/Transform/AIRDependencyScheduleOpt.cpp b/mlir/lib/Transform/AIRDependencyScheduleOpt.cpp index 16fde8750..63276a3e5 100644 --- a/mlir/lib/Transform/AIRDependencyScheduleOpt.cpp +++ b/mlir/lib/Transform/AIRDependencyScheduleOpt.cpp @@ -2150,7 +2150,6 @@ struct BroadcastDetection { air::HerdOp hl_op = nullptr; bool isVariantWrtHerdRows = false; bool isVariantWrtHerdCols = false; - // Create an affine set to represent the broadcast pattern auto ctx = dma_op->getContext(); for (auto v : loop_dep_history) { // Check row-wise or col-wise broadcastable based on variance wrt herd @@ -2183,6 +2182,7 @@ struct BroadcastDetection { isVariantWrtHerdCols = true; } + // Create an affine set to represent the broadcast pattern if (hl_op && isVariantWrtHerdRows && !isVariantWrtHerdCols) { auto numColsOp = dyn_cast( hl_op.getSizeOperands()[1].getDefiningOp()); diff --git a/mlir/lib/Transform/AIRMiscPasses.cpp b/mlir/lib/Transform/AIRMiscPasses.cpp index 159db6490..db6bf6965 100644 --- a/mlir/lib/Transform/AIRMiscPasses.cpp +++ b/mlir/lib/Transform/AIRMiscPasses.cpp @@ -516,20 +516,13 @@ class AIRSpecializeDmaBroadcast // operations in op history, last-in-first-out for (std::vector::reverse_iterator i = op_history.rbegin(); i != op_history.rend(); ++i) { - if (auto air_region_op = dyn_cast(*i)) { - assert(air_region_op.getBody().front().getOperations().size() == - 2 && - "air::ExecuteOp should have only one child operation beside " - "the terminator"); - // Get current scalar op - Operation *op = nullptr; - for (auto &child_op : - air_region_op.getBody().front().getOperations()) { - if (!dyn_cast(child_op)) - op = &child_op; - } + if (auto exec_op = dyn_cast(*i)) { + Operation *op = exec_op.getChildOp(); // If the async op is affine.apply if (auto apply_op = dyn_cast(op)) { + // Can only propagate affine.apply ops with single operand. + if (apply_op.getNumOperands() != 1) + return; auto map = apply_op.getAffineMap(); for (unsigned j = 0; j < current_shape_expr.size(); j++) { if (current_shape_expr[j]) { @@ -538,8 +531,8 @@ class AIRSpecializeDmaBroadcast // Remove dependence from scalar op to memcpyOp if present auto async_memcpyOp = dyn_cast(memcpyOp.getOperation()); - eraseAsyncDependencyFromAsyncOp( - async_memcpyOp, air_region_op.getAsyncToken()); + eraseAsyncDependencyFromAsyncOp(async_memcpyOp, + exec_op.getAsyncToken()); } } }