diff --git a/mlir/lib/Conversion/ConvertToAIRPass.cpp b/mlir/lib/Conversion/ConvertToAIRPass.cpp index 8e1f0616e..24f08dabb 100644 --- a/mlir/lib/Conversion/ConvertToAIRPass.cpp +++ b/mlir/lib/Conversion/ConvertToAIRPass.cpp @@ -62,9 +62,6 @@ matchAndRewriteCopyOp(memref::CopyOp op, RewriterBase &rewriter) { (dst_type.getMemorySpaceAsInt() == (int)air::MemorySpace::L3)) return failure(); - if (!(src_type.hasStaticShape() || dst_type.hasStaticShape())) - return failure(); - SmallVector src_offsets, dst_offsets; SmallVector src_strides, dst_strides; SmallVector src_sizes, dst_sizes; @@ -73,7 +70,9 @@ matchAndRewriteCopyOp(memref::CopyOp op, RewriterBase &rewriter) { auto &strides) { auto subview_offsets = subview.getOffsets().begin(); auto static_offsets = subview.getStaticOffsets(); + auto subview_sizes = subview.getSizes().begin(); auto static_sizes = subview.getStaticSizes(); + auto subview_strides = subview.getStrides().begin(); auto static_strides = subview.getStaticStrides(); auto loc = subview.getLoc(); @@ -98,9 +97,15 @@ matchAndRewriteCopyOp(memref::CopyOp op, RewriterBase &rewriter) { offsets.push_back(*subview_offsets++); } for (auto s : static_sizes) - sizes.push_back(rewriter.create(loc, s)); + if (s >= 0) + sizes.push_back(rewriter.create(loc, s)); + else + sizes.push_back(*subview_sizes++); for (auto s : layout_strides) - strides.push_back(rewriter.create(loc, s)); + if (s >= 0) + strides.push_back(rewriter.create(loc, s)); + else + strides.push_back(*subview_strides++); }; if (auto subview = src.getDefiningOp()) {