From c4d79db207eac4cb178ebff8aae0492f405ba096 Mon Sep 17 00:00:00 2001 From: Nirvedh Meshram <96096277+nirvedhmeshram@users.noreply.github.com> Date: Fri, 1 Mar 2024 12:26:29 -0800 Subject: [PATCH] Relax assumptions on dimensionality in air.dma as with optimizations the number of dimensions may be different than memref rank (#469) --- mlir/lib/Conversion/ConvertToAIRPass.cpp | 34 ------------------------ mlir/lib/Transform/AIRDependency.cpp | 4 +++ mlir/lib/Util/Dependency.cpp | 4 +++ 3 files changed, 8 insertions(+), 34 deletions(-) diff --git a/mlir/lib/Conversion/ConvertToAIRPass.cpp b/mlir/lib/Conversion/ConvertToAIRPass.cpp index 9673c63ee..34ba0e70f 100644 --- a/mlir/lib/Conversion/ConvertToAIRPass.cpp +++ b/mlir/lib/Conversion/ConvertToAIRPass.cpp @@ -1060,9 +1060,6 @@ class AIRDmaToAIRChannelConversion } else return failure(); - auto src_rank = src_type.getRank(); - auto dst_rank = dst_type.getRank(); - SmallVector src_offsets = op.getSrcOffsets(); SmallVector dst_offsets = op.getDstOffsets(); SmallVector src_sizes = op.getSrcSizes(); @@ -1070,20 +1067,6 @@ class AIRDmaToAIRChannelConversion SmallVector src_strides = op.getSrcStrides(); SmallVector dst_strides = op.getDstStrides(); - if (src_offsets.size()) { - if (src_sizes.size() != (unsigned)src_rank) - return failure(); - if (src_strides.size() != (unsigned)src_rank) - return failure(); - } - - if (dst_offsets.size()) { - if (dst_sizes.size() != (unsigned)dst_rank) - return failure(); - if (dst_strides.size() != (unsigned)dst_rank) - return failure(); - } - std::set erased; SmallVector externalGetPut; SmallVector internalGetPut; @@ -1472,9 +1455,6 @@ class AIRDemoteDmaToAIRHierarchyConversion return failure(); // This pass is currently not able to promote in memory // tier - auto src_rank = src_type.getRank(); - auto dst_rank = dst_type.getRank(); - SmallVector src_offsets = op.getSrcOffsets(); SmallVector dst_offsets = op.getDstOffsets(); SmallVector src_sizes = op.getSrcSizes(); @@ -1482,20 +1462,6 @@ class AIRDemoteDmaToAIRHierarchyConversion SmallVector src_strides = op.getSrcStrides(); SmallVector dst_strides = op.getDstStrides(); - if (src_offsets.size()) { - if (src_sizes.size() != (unsigned)src_rank) - return failure(); - if (src_strides.size() != (unsigned)src_rank) - return failure(); - } - - if (dst_offsets.size()) { - if (dst_sizes.size() != (unsigned)dst_rank) - return failure(); - if (dst_strides.size() != (unsigned)dst_rank) - return failure(); - } - std::set erased; { diff --git a/mlir/lib/Transform/AIRDependency.cpp b/mlir/lib/Transform/AIRDependency.cpp index bd91240ff..aff91299d 100644 --- a/mlir/lib/Transform/AIRDependency.cpp +++ b/mlir/lib/Transform/AIRDependency.cpp @@ -341,6 +341,7 @@ class AIRDependency for (unsigned i = 0; i < sink_op_memcpy.getSrcStrides().size(); i++) sink_op_scalar_ins.push_back(sink_op_memcpy.getSrcStrides()[i]); if (sink_op_memcpy.getSrcOffsets().size()) { + numDimsSrc = sink_op_memcpy.getSrcOffsets().size(); for (unsigned i = 0; i < numDimsSrc; i++) { src_indices.push_back(sink_op_memcpy.getSrcOffsets()[i]); } @@ -366,6 +367,7 @@ class AIRDependency for (unsigned i = 0; i < sink_op_memcpy.getDstStrides().size(); i++) sink_op_scalar_outs.push_back(sink_op_memcpy.getDstStrides()[i]); if (sink_op_memcpy.getDstOffsets().size()) { + numDimsDst = sink_op_memcpy.getDstOffsets().size(); for (unsigned i = 0; i < numDimsDst; i++) { dst_indices.push_back(sink_op_memcpy.getDstOffsets()[i]); } @@ -1018,6 +1020,7 @@ class AIRDependency unsigned numDimsSrc = memcpy.getSrcMemref().getType().cast().getRank(); if (memcpy.getSrcOffsets().size()) { + numDimsSrc = memcpy.getSrcOffsets().size(); for (unsigned i = 0; i < numDimsSrc; i++) { src_indices.push_back(memcpy.getSrcOffsets()[i]); } @@ -1034,6 +1037,7 @@ class AIRDependency memcpy.getDstMemref().getType().cast().getRank(); SmallVector dst_indices; if (memcpy.getDstOffsets().size()) { + numDimsDst = memcpy.getDstOffsets().size(); for (unsigned i = 0; i < numDimsDst; i++) { dst_indices.push_back(memcpy.getDstOffsets()[i]); } diff --git a/mlir/lib/Util/Dependency.cpp b/mlir/lib/Util/Dependency.cpp index 5bbd8ef8a..cd48b73ab 100644 --- a/mlir/lib/Util/Dependency.cpp +++ b/mlir/lib/Util/Dependency.cpp @@ -1818,6 +1818,7 @@ void dependencyTracer::pushDepsAtCurrentScope(mlir::Value operand, memcpy.getSrcMemref().getType().cast().getRank(); SmallVector src_indices; if (memcpy.getSrcOffsets().size()) { + numDimsSrc = memcpy.getSrcOffsets().size(); for (unsigned i = 0; i < numDimsSrc; i++) { src_indices.push_back(memcpy.getSrcOffsets()[i]); } @@ -1834,6 +1835,7 @@ void dependencyTracer::pushDepsAtCurrentScope(mlir::Value operand, memcpy.getDstMemref().getType().cast().getRank(); SmallVector dst_indices; if (memcpy.getDstOffsets().size()) { + numDimsDst = memcpy.getDstOffsets().size(); for (unsigned i = 0; i < numDimsDst; i++) { dst_indices.push_back(memcpy.getDstOffsets()[i]); } @@ -2034,6 +2036,7 @@ void dependencyTracer::getPartialMemrefFromOp( for (unsigned i = 0; i < sink_op_memcpy.getSrcStrides().size(); i++) sink_op_scalar_ins.push_back(sink_op_memcpy.getSrcStrides()[i]); if (sink_op_memcpy.getSrcOffsets().size()) { + numDimsSrc = sink_op_memcpy.getSrcOffsets().size(); for (unsigned i = 0; i < numDimsSrc; i++) { src_indices.push_back(sink_op_memcpy.getSrcOffsets()[i]); } @@ -2058,6 +2061,7 @@ void dependencyTracer::getPartialMemrefFromOp( for (unsigned i = 0; i < sink_op_memcpy.getDstStrides().size(); i++) sink_op_scalar_outs.push_back(sink_op_memcpy.getDstStrides()[i]); if (sink_op_memcpy.getDstOffsets().size()) { + numDimsDst = sink_op_memcpy.getDstOffsets().size(); for (unsigned i = 0; i < numDimsDst; i++) { dst_indices.push_back(sink_op_memcpy.getDstOffsets()[i]); }