From 9503917b7f5e86c0b5fd73bdde3b8fa18d6604fa Mon Sep 17 00:00:00 2001 From: erwei-xilinx Date: Thu, 7 Mar 2024 13:29:51 -0800 Subject: [PATCH] =?UTF-8?q?Fixup=20a=20bug=20when=20removing=20redundant?= =?UTF-8?q?=20dimensions;=20enable=20clearing=20of=20wr=E2=80=A6=20(#479)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Fixup a bug when removing redundant dimensions; enable clearing of wrap-and-stride lists if default access pattern * Rewrite success/failure conditions * Code format * Clang format --- mlir/include/air/Util/Util.h | 3 +- mlir/lib/Conversion/AIRRtToIpuPass.cpp | 19 ++++++- .../Transform/AIRDependencyScheduleOpt.cpp | 52 +++++++++++-------- mlir/lib/Util/Util.cpp | 45 ++++++++++------ .../specialize-channel-wrap-and-stride.mlir | 29 ++++++++--- 5 files changed, 101 insertions(+), 47 deletions(-) diff --git a/mlir/include/air/Util/Util.h b/mlir/include/air/Util/Util.h index bae0da09d..dac146185 100644 --- a/mlir/include/air/Util/Util.h +++ b/mlir/include/air/Util/Util.h @@ -144,7 +144,8 @@ void foldForLoopNestAsExtendedSizesAndStrides( LogicalResult canonicalizeWrapAndStrideList(OpBuilder builder, SmallVector &offsets, SmallVector &sizes, - SmallVector &strides); + SmallVector &strides, + int memref_volume); } // namespace air } // namespace xilinx diff --git a/mlir/lib/Conversion/AIRRtToIpuPass.cpp b/mlir/lib/Conversion/AIRRtToIpuPass.cpp index 3a77f146e..ef77171e2 100644 --- a/mlir/lib/Conversion/AIRRtToIpuPass.cpp +++ b/mlir/lib/Conversion/AIRRtToIpuPass.cpp @@ -667,8 +667,23 @@ specializeAffineForInAIRRtDmaWrapAndStride(OpBuilder builder, strides.push_back(i64_one); // Canonicalize wraps and strides - (void)air::canonicalizeWrapAndStrideList(builder, offsets, wraps, strides); - + (void)air::canonicalizeWrapAndStrideList( + builder, offsets, wraps, strides, air::getTensorVolume(memref.getType())); + + // If empty offsets/sizes/strides, then populate the lists with default + // values. + if (offsets.empty() && wraps.empty() && strides.empty()) { + auto memref_shape = air::getTensorShape(memref.getType()); + int current_stride = air::getTensorVolume(memref.getType()); + for (unsigned i = 0; i < memref_shape.size(); i++) { + offsets.push_back(builder.create(loc, 0)); + wraps.push_back( + builder.create(loc, memref_shape[i])); + current_stride /= memref_shape[i]; + strides.push_back( + builder.create(loc, current_stride)); + } + } xilinx::air::foldForLoopNestAsExtendedSizesAndStrides( builder, for_op.getOperation(), memcpy_ops[0].getOperation(), offsets, wraps, strides, memcpy_ops[0]->getOperand(3)); diff --git a/mlir/lib/Transform/AIRDependencyScheduleOpt.cpp b/mlir/lib/Transform/AIRDependencyScheduleOpt.cpp index 64a10c413..16fde8750 100644 --- a/mlir/lib/Transform/AIRDependencyScheduleOpt.cpp +++ b/mlir/lib/Transform/AIRDependencyScheduleOpt.cpp @@ -1699,6 +1699,22 @@ struct AIRSpecializeChannelWrapAndStrideInScfFor SmallVector offsets = channel_ops[0].getOffsets(); SmallVector wraps = channel_ops[0].getSizes(); SmallVector strides = channel_ops[0].getStrides(); + for (auto o : for_loops) { + // Check for perfect loop nest containing only air.channel ops + if (!hasNElements(o.getBody(), 1)) + return failure(); + if (isa(o.getBody()->begin())) { + } else if (isa(o.getBody()->begin())) { + } else + return failure(); + if (!getStaticScfForTripCountAsInt(o)) + return failure(); + } + + (void)canonicalizeWrapAndStrideList( + rewriter, offsets, wraps, strides, + air::getTensorVolume(channel_ops[0].getMemref().getType())); + // If empty offsets/sizes/strides, then populate the lists with default // values. if (offsets.empty() && wraps.empty() && strides.empty()) { @@ -1714,25 +1730,13 @@ struct AIRSpecializeChannelWrapAndStrideInScfFor rewriter.create(loc, current_stride)); } } - for (auto o : for_loops) { - // Check for perfect loop nest containing only air.channel ops - if (!hasNElements(o.getBody(), 1)) - return failure(); - if (isa(o.getBody()->begin())) { - } else if (isa(o.getBody()->begin())) { - } else - return failure(); - if (!getStaticScfForTripCountAsInt(o)) - return failure(); - } - - (void)canonicalizeWrapAndStrideList(rewriter, offsets, wraps, strides); - foldForLoopNestAsExtendedSizesAndStrides( rewriter, for_op.getOperation(), channel_ops[0].getOperation(), offsets, wraps, strides, channel_ops[0].getMemref()); - (void)canonicalizeWrapAndStrideList(rewriter, offsets, wraps, strides); + (void)canonicalizeWrapAndStrideList( + rewriter, offsets, wraps, strides, + air::getTensorVolume(channel_ops[0].getMemref().getType())); Operation *new_chan_op = nullptr; SmallVector tys; @@ -1823,13 +1827,17 @@ struct AIRSpecializeChannelWrapAndStrideInAffineFor return failure(); } - (void)canonicalizeWrapAndStrideList(rewriter, offsets, wraps, strides); + (void)canonicalizeWrapAndStrideList( + rewriter, offsets, wraps, strides, + air::getTensorVolume(channel_ops[0].getMemref().getType())); foldForLoopNestAsExtendedSizesAndStrides( rewriter, for_op.getOperation(), channel_ops[0].getOperation(), offsets, wraps, strides, channel_ops[0].getMemref()); - (void)canonicalizeWrapAndStrideList(rewriter, offsets, wraps, strides); + (void)canonicalizeWrapAndStrideList( + rewriter, offsets, wraps, strides, + air::getTensorVolume(channel_ops[0].getMemref().getType())); Operation *new_chan_op = nullptr; SmallVector tys; @@ -1881,8 +1889,9 @@ struct AIRCanonicalizeChannelPutOpWrapAndStrideList SmallVector sizes = op.getSizes(); SmallVector strides = op.getStrides(); - if (failed( - canonicalizeWrapAndStrideList(rewriter, offsets, sizes, strides))) + if (failed(canonicalizeWrapAndStrideList( + rewriter, offsets, sizes, strides, + air::getTensorVolume(op.getMemref().getType())))) return failure(); auto new_op = rewriter.create( @@ -1917,8 +1926,9 @@ struct AIRCanonicalizeChannelGetOpWrapAndStrideList SmallVector sizes = op.getSizes(); SmallVector strides = op.getStrides(); - if (failed( - canonicalizeWrapAndStrideList(rewriter, offsets, sizes, strides))) + if (failed(canonicalizeWrapAndStrideList( + rewriter, offsets, sizes, strides, + air::getTensorVolume(op.getMemref().getType())))) return failure(); auto new_op = rewriter.create( diff --git a/mlir/lib/Util/Util.cpp b/mlir/lib/Util/Util.cpp index efc1e5dbb..f75155fa6 100644 --- a/mlir/lib/Util/Util.cpp +++ b/mlir/lib/Util/Util.cpp @@ -806,8 +806,10 @@ void air::getDefiningOpsToOperands(Operation *op, LogicalResult air::canonicalizeWrapAndStrideList(OpBuilder builder, SmallVector &offsets, SmallVector &sizes, - SmallVector &strides) { + SmallVector &strides, + int memref_volume) { + bool listsHaveChanged = false; // Match offsets size with sizes and strides int max_dim_size = std::max(std::max(offsets.size(), sizes.size()), strides.size()); @@ -816,6 +818,7 @@ LogicalResult air::canonicalizeWrapAndStrideList(OpBuilder builder, offsets.insert(offsets.begin(), builder.create( builder.getUnknownLoc(), 0)); } + listsHaveChanged = true; } SmallVector unit_dims; @@ -831,40 +834,50 @@ LogicalResult air::canonicalizeWrapAndStrideList(OpBuilder builder, offsets.erase(offsets.begin() + i); sizes.erase(sizes.begin() + i); strides.erase(strides.begin() + i); + listsHaveChanged = true; } - SmallVector redundant_dims; - if (!sizes.empty()) + if (!sizes.empty()) { for (int i = sizes.size() - 1; i >= 1; i--) { - if (getConstantIntValue(sizes[i]) && getConstantIntValue(sizes[i - 1]) && + if (getConstantIntValue(offsets[i]) && + getConstantIntValue(offsets[i - 1]) && + getConstantIntValue(sizes[i]) && getConstantIntValue(sizes[i - 1]) && getConstantIntValue(strides[i]) && getConstantIntValue(strides[i - 1])) { auto const_size = *getConstantIntValue(sizes[i]); auto const_size_next = *getConstantIntValue(sizes[i - 1]); auto const_stride = *getConstantIntValue(strides[i]); auto const_stride_next = *getConstantIntValue(strides[i - 1]); - // Skip over the first dimension if stride is 1 - if (const_stride == 1 && i == (int)sizes.size() - 1) - continue; if (const_stride_next == const_size * const_stride) { - redundant_dims.push_back(i - 1); sizes[i] = builder.create( builder.getUnknownLoc(), const_size * const_size_next); + offsets.erase(offsets.begin() + i - 1); + sizes.erase(sizes.begin() + i - 1); + strides.erase(strides.begin() + i - 1); + listsHaveChanged = true; } } } - - for (auto i : redundant_dims) { - offsets.erase(offsets.begin() + i); - sizes.erase(sizes.begin() + i); - strides.erase(strides.begin() + i); } - if (unit_dims.empty() && redundant_dims.empty()) { - return failure(); + // If default data access pattern, then clear the offsets, sizes and strides. + if (offsets.size() == 1 && sizes.size() == 1 && strides.size() == 1) { + if (getConstantIntValue(offsets[0]) && getConstantIntValue(sizes[0]) && + getConstantIntValue(strides[0])) { + if (*getConstantIntValue(strides[0]) == 1 && + *getConstantIntValue(sizes[0]) == memref_volume) { + offsets.erase(offsets.begin()); + sizes.erase(sizes.begin()); + strides.erase(strides.begin()); + listsHaveChanged = true; + } + } } - return success(); + if (listsHaveChanged) + return success(); + else + return failure(); } // Fold perfectly nested for loops as extra entries in wraps and strides diff --git a/mlir/test/Transform/AIRDependencyScheduleOpt/specialize-channel-wrap-and-stride.mlir b/mlir/test/Transform/AIRDependencyScheduleOpt/specialize-channel-wrap-and-stride.mlir index 5de9b5fc6..e89790555 100644 --- a/mlir/test/Transform/AIRDependencyScheduleOpt/specialize-channel-wrap-and-stride.mlir +++ b/mlir/test/Transform/AIRDependencyScheduleOpt/specialize-channel-wrap-and-stride.mlir @@ -66,7 +66,7 @@ module { } // CHECK-LABEL: test1 - // CHECK: put @channel_1[%c0, %c0] (%arg0[%c0, %c0] [%c4, %c32] [%c32, %c1]) : (memref<128xf32>) + // CHECK: put @channel_1[%c0, %c0] (%arg0[] [] []) : (memref<128xf32>) // CHECK: get @channel_2[%c0, %c0] (%arg1[%c0, %c0] [%c128, %c32] [%c128, %c1]) : (memref<128x128xf32>) // CHECK: put @channel_3[%c0, %c0] (%arg1[%c0, %c0, %c0] [%c4, %c32, %c32] [%c32, %c128, %c1]) : (memref<128x128xf32>) // CHECK: put @channel_4[%c0, %c0] (%arg1[%c0, %c0, %c0] [%c4, %c128, %c32] [%c32, %c128, %c1]) : (memref<128x128xf32>) @@ -101,7 +101,7 @@ module { } // CHECK-LABEL: test2 - // CHECK: put @channel_6[%c0, %c0] (%arg0[%c0, %c0] [%c4, %c32] [%c32, %c1]) : (memref<128xf32>) + // CHECK: put @channel_6[%c0, %c0] (%arg0[] [] []) : (memref<128xf32>) // CHECK: get @channel_7[%c0, %c0] (%arg1[%c0, %c0] [%c128, %c32] [%c128, %c1]) : (memref<128x128xf32>) // CHECK: put @channel_8[%c0, %c0] (%arg1[%c0, %c0, %c0] [%c4, %c32, %c32] [%c32, %c128, %c1]) : (memref<128x128xf32>) // CHECK: put @channel_9[%c0, %c0] (%arg1[%c0, %c0, %c0] [%c4, %c128, %c32] [%c32, %c128, %c1]) : (memref<128x128xf32>) @@ -136,10 +136,10 @@ module { } // CHECK-LABEL: test3 - // CHECK: put @channel_11[%c0, %c0] (%arg0[%c0, %c0] [%c4, %c32] [%c32, %c1]) : (memref<128xf32>) + // CHECK: put @channel_11[%c0, %c0] (%arg0[] [] []) : (memref<128xf32>) // CHECK: put @channel_12[%c0, %c0] (%arg1[%c0, %c0, %c0] [%c4, %c128, %c32] [%c32, %c128, %c1]) : (memref<128x128xf32>) - // CHECK: put @channel_13[%c0, %c0] (%arg0[%c0, %c0] [%c4, %c32] [%c32, %c1]) : (memref<128xf32>) - // CHECK: put async [%0] @channel_14[%c0, %c0] (%arg0[%c0, %c0] [%c4, %c32] [%c32, %c1]) : (memref<128xf32>) + // CHECK: put @channel_13[%c0, %c0] (%arg0[] [] []) : (memref<128xf32>) + // CHECK: put async [%0] @channel_14[%c0, %c0] (%arg0[] [] []) : (memref<128xf32>) func.func @test3(%arg0: memref<128xf32>, %arg1: memref<128x128xf32>) -> memref<128xf32> { %c0 = arith.constant 0 : index @@ -177,7 +177,7 @@ module { // CHECK-LABEL: test4 // CHECK: put async @channel_15[%c0, %c0] (%arg0[%c0] [%c32] [%c1]) : (memref<128xf32>) - // CHECK: put async @channel_16[%c0, %c0] (%arg1[%c0, %c0] [%c16, %c4] [%c4, %c1]) : (memref<128x128xf32>) + // CHECK: put async @channel_16[%c0, %c0] (%arg1[%c0] [%c64] [%c1]) : (memref<128x128xf32>) func.func @test4(%arg0: memref<128xf32>, %arg1: memref<128x128xf32>) -> memref<128xf32> { %c0 = arith.constant 0 : index @@ -196,7 +196,7 @@ module { } // CHECK-LABEL: test5 - // CHECK: put async @channel_17[] (%arg0[%c0, %c0, %c0] [%c8, %c32, %c32] [%c0, %c32, %c1]) : (memref<32x32xf32>) + // CHECK: put async @channel_17[] (%arg0[%c0, %c0] [%c8, %c1024] [%c0, %c1]) : (memref<32x32xf32>) func.func @test5(%arg0: memref<32x32xf32>) -> memref<32x32xf32> { %c0 = arith.constant 0 : index @@ -209,4 +209,19 @@ module { } return %alloc : memref<32x32xf32> } + + // CHECK-LABEL: test6 + // CHECK: put async @channel_18[] (%arg0[] [] []) : (memref<1x1x4x2x8x4xi32>) + + func.func @test6(%arg0: memref<1x1x4x2x8x4xi32>) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c2 = arith.constant 2 : index + %c4 = arith.constant 4 : index + %c8 = arith.constant 8 : index + %c32 = arith.constant 32 : index + %c64 = arith.constant 64 : index + %1 = air.channel.put async @channel_18[] (%arg0[%c0, %c0, %c0, %c0] [%c4, %c2, %c8, %c4] [%c64, %c32, %c4, %c1]) : (memref<1x1x4x2x8x4xi32>) + return + } }