From 690ba0ec5286dade47f0c8f73f6c67cd75df3ab2 Mon Sep 17 00:00:00 2001 From: erwei-xilinx Date: Wed, 24 Jan 2024 16:53:48 -0800 Subject: [PATCH] Fold a chain of memref.subview/expand_shape/transpose ops to wraps and strides in air.dma (#395) * Add a pass at the end of -air-copy-to-dma to fold a chain of memref data rearrangement ops to wraps and strides * Update formats; clean ups; remove the memref op erase at the end * Revert a change done by mistake --- mlir/lib/Conversion/ConvertToAIRPass.cpp | 343 ++++++++++++++++++ .../condense_memref_ops_to_air_memcpy.mlir | 92 +++++ 2 files changed, 435 insertions(+) create mode 100644 mlir/test/Conversion/ConvertToAIR/condense_memref_ops_to_air_memcpy.mlir diff --git a/mlir/lib/Conversion/ConvertToAIRPass.cpp b/mlir/lib/Conversion/ConvertToAIRPass.cpp index 2468eb1b5..8b41048ba 100644 --- a/mlir/lib/Conversion/ConvertToAIRPass.cpp +++ b/mlir/lib/Conversion/ConvertToAIRPass.cpp @@ -2318,6 +2318,273 @@ class ScfForallToLaunchConversion : public OpRewritePattern { // bool generateSegment; }; +/// Build a strided memref type by applying `permutationMap` tp `memRefType`. +static MemRefType inferTransposeResultType(MemRefType memRefType, + AffineMap permutationMap) { + auto rank = memRefType.getRank(); + auto originalSizes = memRefType.getShape(); + auto [originalStrides, offset] = getStridesAndOffset(memRefType); + assert(originalStrides.size() == static_cast(rank)); + + // Compute permuted sizes and strides. + SmallVector sizes(rank, 0); + SmallVector strides(rank, 1); + for (const auto &en : llvm::enumerate(permutationMap.getResults())) { + unsigned position = cast(en.value()).getPosition(); + sizes[en.index()] = originalSizes[position]; + strides[en.index()] = originalStrides[position]; + } + + return MemRefType::Builder(memRefType) + .setShape(sizes) + .setLayout( + StridedLayoutAttr::get(memRefType.getContext(), offset, strides)); +} + +static SmallVector extractStridesFromMemrefType(MemRefType memrefTy, + OpBuilder &builder) { + // get the strides and offsets from the memref type + SmallVector strides; + int64_t offset; + SmallVector layout_strides; + auto successStrides = getStridesAndOffset(memrefTy, layout_strides, offset); + if (failed(successStrides)) { + llvm::outs() << "Failed to get strides\n"; + return strides; + } + + for (auto s : layout_strides) + strides.push_back( + builder.create(builder.getUnknownLoc(), s)); + + return strides; +} + +static SmallVector extractSizesFromMemrefType(MemRefType memrefTy, + OpBuilder &builder) { + SmallVector sizes; + for (auto s : memrefTy.getShape()) + sizes.push_back( + builder.create(builder.getUnknownLoc(), s)); + return sizes; +} + +static void extractOffsetsFromSubview(memref::SubViewOp subview, + OpBuilder &builder, + SmallVector &offsets) { + auto subview_offsets = subview.getOffsets().begin(); + auto static_offsets = subview.getStaticOffsets(); + auto loc = subview.getLoc(); + + for (auto o : static_offsets) { + if (o >= 0) + offsets.push_back(builder.create(loc, o)); + else + offsets.push_back(*subview_offsets++); + } +} + +static LogicalResult canonicalizeAIRDmaOperands(OpBuilder builder, + SmallVector &offsets, + SmallVector &sizes, + SmallVector &strides, + MemRefType memref) { + // Increase vector sizes up to memref size. When offsets, sizes and strides + // are all empty, then it implies that the whole memref is accessed in the + // default order. + int max_dim_size = + std::max(std::max(offsets.size(), sizes.size()), strides.size()); + if (max_dim_size && offsets.size() < memref.getRank()) { + for (unsigned i = offsets.size(); i < memref.getRank(); i++) { + offsets.insert(offsets.begin(), builder.create( + builder.getUnknownLoc(), 0)); + } + } + if (max_dim_size && sizes.size() < memref.getRank()) { + for (unsigned i = sizes.size(); i < memref.getRank(); i++) { + sizes.insert(sizes.begin(), builder.create( + builder.getUnknownLoc(), 1)); + } + } + int memref_size = 1; + for (auto size : memref.getShape()) + memref_size *= size; + if (max_dim_size && strides.size() < memref.getRank()) { + for (unsigned i = strides.size(); i < memref.getRank(); i++) { + strides.insert(strides.begin(), + builder.create( + builder.getUnknownLoc(), memref_size)); + } + } + + // Reduce highest dimensions if more than memref size + while (strides.size() > memref.getRank() && getConstantIntValue(strides[0]) && + *getConstantIntValue(strides[0]) == memref_size) { + strides.erase(strides.begin()); + } + while (sizes.size() > memref.getRank() && getConstantIntValue(sizes[0]) && + *getConstantIntValue(sizes[0]) == 1) { + sizes.erase(sizes.begin()); + } + while (offsets.size() > std::min(sizes.size(), strides.size()) && + getConstantIntValue(offsets[0]) && + *getConstantIntValue(offsets[0]) == 0) { + offsets.erase(offsets.begin()); + } + + if (offsets.size() != sizes.size() || sizes.size() != strides.size()) + return failure(); + + return success(); +} + +static LogicalResult condenseMemrefDataReorderingToAIRDma( + air::DmaMemcpyNdOp dmaOp, std::vector src_ancestor_memref_ops, + std::vector dst_ancestor_memref_ops) { + OpBuilder rewriter(dmaOp); + auto src = dmaOp.getSrcMemref(); + auto dst = dmaOp.getDstMemref(); + auto loc = dmaOp->getLoc(); + + // It must already be a memref + auto src_type = src.getType().dyn_cast(); + auto dst_type = dst.getType().dyn_cast(); + if (!src_type) + return failure(); + if (!(src_type.hasStaticShape() || dst_type.hasStaticShape())) + return failure(); + + // Revert the vector of memref ops, as it was built with push_back. + std::reverse(src_ancestor_memref_ops.begin(), src_ancestor_memref_ops.end()); + std::reverse(dst_ancestor_memref_ops.begin(), dst_ancestor_memref_ops.end()); + + SmallVector src_offsets, dst_offsets; + SmallVector src_strides, dst_strides; + SmallVector src_sizes, dst_sizes; + SmallVector empty; + + MemRefType src_memref_ty; + if (!src_ancestor_memref_ops.empty()) { + if (auto subviewOp = + dyn_cast(src_ancestor_memref_ops[0])) { + extractOffsetsFromSubview(subviewOp, rewriter, src_offsets); + src_memref_ty = subviewOp.getSourceType(); + src = subviewOp.getSource(); + } else if (auto transposeOp = + dyn_cast(src_ancestor_memref_ops[0])) { + src_memref_ty = transposeOp.getIn().getType().cast(); + src = transposeOp.getIn(); + } + } + MemRefType dst_memref_ty; + if (!dst_ancestor_memref_ops.empty()) { + if (auto subviewOp = + dyn_cast(dst_ancestor_memref_ops[0])) { + extractOffsetsFromSubview(subviewOp, rewriter, dst_offsets); + dst_memref_ty = subviewOp.getSourceType(); + dst = subviewOp.getSource(); + } else if (auto transposeOp = + dyn_cast(dst_ancestor_memref_ops[0])) { + dst_memref_ty = transposeOp.getIn().getType().cast(); + dst = transposeOp.getIn(); + } + } + + for (auto memrefOp : src_ancestor_memref_ops) { + if (auto transposeOp = dyn_cast(memrefOp)) { + src_memref_ty = + inferTransposeResultType(src_memref_ty, transposeOp.getPermutation()); + } else if (auto expandShapeOp = dyn_cast(memrefOp)) { + FailureOr compute_expand = + memref::ExpandShapeOp::computeExpandedType( + src_memref_ty, expandShapeOp.getResultType().getShape(), + expandShapeOp.getReassociationIndices()); + if (failed(compute_expand)) { + assert(false); + } else { + src_memref_ty = *compute_expand; + } + } else if (auto subviewOp = dyn_cast(memrefOp)) { + // Check if subview is rank reduced + if (subviewOp.getSourceType().getRank() > subviewOp.getType().getRank()) + src_memref_ty = + memref::SubViewOp::inferRankReducedResultType( + subviewOp.getType().getShape(), src_memref_ty, + subviewOp.getStaticOffsets(), subviewOp.getStaticSizes(), + subviewOp.getStaticStrides()) + .cast(); + else + src_memref_ty = + memref::SubViewOp::inferResultType( + src_memref_ty, subviewOp.getStaticOffsets(), + subviewOp.getStaticSizes(), subviewOp.getStaticStrides()) + .cast(); + } + } + + for (auto memrefOp : dst_ancestor_memref_ops) { + if (auto transposeOp = dyn_cast(memrefOp)) { + dst_memref_ty = + inferTransposeResultType(dst_memref_ty, transposeOp.getPermutation()); + } else if (auto expandShapeOp = dyn_cast(memrefOp)) { + FailureOr compute_expand = + memref::ExpandShapeOp::computeExpandedType( + dst_memref_ty, expandShapeOp.getResultType().getShape(), + expandShapeOp.getReassociationIndices()); + if (failed(compute_expand)) { + assert(false); + } else { + dst_memref_ty = *compute_expand; + } + } else if (auto subviewOp = dyn_cast(memrefOp)) { + if (subviewOp.getSourceType().getRank() > subviewOp.getType().getRank()) + dst_memref_ty = + memref::SubViewOp::inferRankReducedResultType( + subviewOp.getType().getShape(), dst_memref_ty, + subviewOp.getStaticOffsets(), subviewOp.getStaticSizes(), + subviewOp.getStaticStrides()) + .cast(); + else + dst_memref_ty = + memref::SubViewOp::inferResultType( + dst_memref_ty, subviewOp.getStaticOffsets(), + subviewOp.getStaticSizes(), subviewOp.getStaticStrides()) + .cast(); + } + } + + if (src_ancestor_memref_ops.size()) { + src_strides = extractStridesFromMemrefType(src_memref_ty, rewriter); + src_sizes = extractSizesFromMemrefType(src_memref_ty, rewriter); + } + if (dst_ancestor_memref_ops.size()) { + dst_strides = extractStridesFromMemrefType(dst_memref_ty, rewriter); + dst_sizes = extractSizesFromMemrefType(dst_memref_ty, rewriter); + } + + SmallVector deps; + SmallVector tys; + + if (failed(canonicalizeAIRDmaOperands(rewriter, src_offsets, src_sizes, + src_strides, + src.getType().cast())) || + failed(canonicalizeAIRDmaOperands(rewriter, dst_offsets, dst_sizes, + dst_strides, + dst.getType().cast()))) { + assert(false); + } + auto new_dma = rewriter.create( + loc, tys, deps, dst, dst_offsets, dst_sizes, dst_strides, src, + src_offsets, src_sizes, src_strides); + + assert(!new_dma.getSrcMemref().getDefiningOp()); + assert(!new_dma.getDstMemref().getDefiningOp()); + + dmaOp->erase(); + + return success(); +} + struct CopyToDmaPass : public air::impl::CopyToDmaBase { CopyToDmaPass() = default; @@ -2389,6 +2656,82 @@ struct CopyToDmaPass : public air::impl::CopyToDmaBase { LLVM_DEBUG(llvm::outs() << "output\n"); LLVM_DEBUG(module.print(llvm::outs())); + + // Condense memref data pattern reordering ops, including memref.subview, + // memref.tranpose and memref.expand_shape into air.dma_memcpy_nd op's + // offsets, sizes and strides fields. + auto scope = getOperation(); + std::vector, + std::vector>> + dma_ops; + + scope->walk([&](xilinx::air::DmaMemcpyNdOp dmaOp) { + bool src_condense = false; + if (auto src_defop = dmaOp.getSrcMemref().getDefiningOp()) { + src_condense |= isa(src_defop); + src_condense |= isa(src_defop); + src_condense |= isa(src_defop); + } + bool dst_condense = false; + if (auto dst_defop = dmaOp.getDstMemref().getDefiningOp()) { + dst_condense |= isa(dst_defop); + dst_condense |= isa(dst_defop); + dst_condense |= isa(dst_defop); + } + if (src_condense || dst_condense) { + // Fields in the tuple: (1) dma op, (2) list of memref ops producing the + // src memref, and (3) list of memref ops producing the dst memref. + std::tuple, + std::vector> + log_entry; + std::get<0>(log_entry) = dmaOp; + if (src_condense) { + Operation *ancestor = dmaOp.getSrcMemref().getDefiningOp(); + bool exit = false; + while (ancestor && !exit) { + if (auto transpose_anc = dyn_cast(ancestor)) { + std::get<1>(log_entry).push_back(ancestor); + ancestor = transpose_anc.getIn().getDefiningOp(); + } else if (auto expand_anc = + dyn_cast(ancestor)) { + std::get<1>(log_entry).push_back(ancestor); + ancestor = expand_anc.getSrc().getDefiningOp(); + } else if (auto subview_anc = + dyn_cast(ancestor)) { + std::get<1>(log_entry).push_back(ancestor); + ancestor = subview_anc.getSource().getDefiningOp(); + } else + exit = true; + } + } + if (dst_condense) { + Operation *ancestor = dmaOp.getDstMemref().getDefiningOp(); + bool exit = false; + while (ancestor && !exit) { + if (auto transpose_anc = dyn_cast(ancestor)) { + std::get<2>(log_entry).push_back(ancestor); + ancestor = transpose_anc.getIn().getDefiningOp(); + } else if (auto expand_anc = + dyn_cast(ancestor)) { + std::get<2>(log_entry).push_back(ancestor); + ancestor = expand_anc.getSrc().getDefiningOp(); + } else if (auto subview_anc = + dyn_cast(ancestor)) { + std::get<2>(log_entry).push_back(ancestor); + ancestor = subview_anc.getSource().getDefiningOp(); + } else + exit = true; + } + } + dma_ops.push_back(log_entry); + } + }); + for (auto dmaOp : dma_ops) { + if (failed(condenseMemrefDataReorderingToAIRDma( + std::get<0>(dmaOp), std::get<1>(dmaOp), std::get<2>(dmaOp)))) { + return signalPassFailure(); + } + } } }; diff --git a/mlir/test/Conversion/ConvertToAIR/condense_memref_ops_to_air_memcpy.mlir b/mlir/test/Conversion/ConvertToAIR/condense_memref_ops_to_air_memcpy.mlir new file mode 100644 index 000000000..da50694d6 --- /dev/null +++ b/mlir/test/Conversion/ConvertToAIR/condense_memref_ops_to_air_memcpy.mlir @@ -0,0 +1,92 @@ +//===- condense_memref_ops_to_air_memcpy.mlir ------------------*- MLIR -*-===// +// +// Copyright (C) 2024, Advanced Micro Devices, Inc. All rights reserved. +// SPDX-License-Identifier: MIT +// +//===----------------------------------------------------------------------===// + +// RUN: air-opt %s -air-copy-to-dma -canonicalize -cse | FileCheck %s + +// CHECK: %[[CST128:.*]] = arith.constant 128 : index +// CHECK: %[[CST32:.*]] = arith.constant 32 : index +// CHECK: %[[CST8:.*]] = arith.constant 8 : index +// CHECK: %[[CST16:.*]] = arith.constant 16 : index +// CHECK: %[[CST0:.*]] = arith.constant 0 : index +// CHECK: %[[CST1:.*]] = arith.constant 1 : index +// CHECK: air.dma_memcpy_nd (%{{.*}}[] [] [], %{{.*}}[%{{.*}}, %[[CST0]]] [%[[CST8]], %[[CST16]]] [%[[CST16]], %[[CST1]]]) : (memref<1x1x8x16xi32, 1>, memref<8x16xi32>) +// CHECK: air.dma_memcpy_nd (%{{.*}}[] [] [], %{{.*}}[%[[CST0]], %{{.*}}] [%[[CST16]], %[[CST16]]] [%[[CST32]], %[[CST1]]]) : (memref<1x1x16x16xi32, 1>, memref<16x32xi32>) +// CHECK: air.herd @herd_0 +// CHECK: %[[CST32_0:.*]] = arith.constant 32 : index +// CHECK: %[[CST4_0:.*]] = arith.constant 4 : index +// CHECK: %[[CST2_0:.*]] = arith.constant 2 : index +// CHECK: %[[CST1_0:.*]] = arith.constant 1 : index +// CHECK: %[[CST16_0:.*]] = arith.constant 16 : index +// CHECK: %[[CST64_0:.*]] = arith.constant 64 : index +// CHECK: %[[CST8_0:.*]] = arith.constant 8 : index +// CHECK: %[[CST128_0:.*]] = arith.constant 128 : index +// CHECK: %[[CST0_0:.*]] = arith.constant 0 : index +// CHECK: air.dma_memcpy_nd (%{{.*}}[] [] [], %{{.*}}[%{{.*}}, %[[CST0_0]], %[[CST0_0]], %[[CST0_0]]] [%[[CST2_0]], %[[CST2_0]], %[[CST4_0]], %[[CST8_0]]] [%[[CST8_0]], %[[CST64_0]], %[[CST16_0]], %[[CST1_0]]]) : (memref<1x1x2x2x4x8xi32, 2>, memref<1x1x8x16xi32, 1>) +// CHECK: air.dma_memcpy_nd (%{{.*}}[] [] [], %{{.*}}[%[[CST0_0]], %{{.*}}, %[[CST0_0]], %[[CST0_0]]] [%[[CST2_0]], %[[CST2_0]], %[[CST8_0]], %[[CST8_0]]] [%[[CST8_0]], %[[CST128_0]], %[[CST16_0]], %[[CST1_0]]]) : (memref<1x1x2x2x8x8xi32, 2>, memref<1x1x16x16xi32, 1>) +// CHECK: air.dma_memcpy_nd (%{{.*}}[%{{.*}}, %{{.*}}, %[[CST0_0]], %[[CST0_0]]] [%[[CST1_0]], %[[CST1_0]], %[[CST8_0]], %[[CST16_0]]] [%[[CST128_0]], %[[CST128_0]], %[[CST16_0]], %[[CST1_0]]], %{{.*}}[%[[CST0_0]], %[[CST0_0]], %[[CST0_0]], %[[CST0_0]], %[[CST0_0]], %[[CST0_0]]] [%[[CST1_0]], %[[CST1_0]], %[[CST2_0]], %[[CST4_0]], %[[CST2_0]], %[[CST8_0]]] [%[[CST128_0]], %[[CST128_0]], %[[CST32_0]], %[[CST8_0]], %[[CST64_0]], %[[CST1_0]]]) : (memref<1x1x8x16xi32, 1>, memref<1x1x2x2x4x8xi32, 2>) +// CHECK: air.herd_terminator +// CHECK: air.dma_memcpy_nd (%{{.*}}[%{{.*}}, %{{.*}}] [%[[CST8]], %[[CST16]]] [%[[CST32]], %[[CST1]]], %{{.*}}[%[[CST0]], %[[CST0]], %[[CST0]], %[[CST0]]] [%[[CST1]], %[[CST1]], %[[CST8]], %[[CST16]]] [%[[CST128]], %[[CST128]], %[[CST16]], %[[CST1]]]) : (memref<8x32xi32>, memref<1x1x8x16xi32, 1>) + +#map = affine_map<()[s0] -> (s0 * 8)> +#map1 = affine_map<()[s0] -> (s0 * 16)> +#map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d0, d2, d5, d3, d6, d8)> +#map3 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d2, d1, d4, d5, d8, d7)> +#map4 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d0, d1, d4, d3, d6, d7)> +module { + func.func @func0(%0 : memref<8x16xi32>, %1 : memref<16x32xi32>, %2 : memref<8x32xi32>) { + %c2 = arith.constant 2 : index + %c1 = arith.constant 1 : index + %c0 = arith.constant 0 : index + air.launch (%arg0, %arg1) in (%arg2=%c1, %arg3=%c2) args(%arg4=%0, %arg5=%1, %arg6=%2) : memref<8x16xi32>, memref<16x32xi32>, memref<8x32xi32> { + air.segment @segment_0 args(%arg7=%arg0, %arg8=%arg1, %arg9=%arg4, %arg10=%arg5, %arg11=%arg6) : index, index, memref<8x16xi32>, memref<16x32xi32>, memref<8x32xi32> { + %c1_0 = arith.constant 1 : index + %3 = affine.apply #map()[%arg7] + %4 = affine.apply #map1()[%arg8] + %subview = memref.subview %arg9[%3, 0] [8, 16] [1, 1] : memref<8x16xi32> to memref<8x16xi32, strided<[16, 1], offset: ?>> + %subview_1 = memref.subview %arg10[0, %4] [16, 16] [1, 1] : memref<16x32xi32> to memref<16x16xi32, strided<[32, 1], offset: ?>> + %subview_2 = memref.subview %arg11[%3, %4] [8, 16] [1, 1] : memref<8x32xi32> to memref<8x16xi32, strided<[32, 1], offset: ?>> + %alloc = memref.alloc() : memref<1x1x8x16xi32, 1> + %transpose = memref.transpose %subview (d0, d1) -> (d0, d1) : memref<8x16xi32, strided<[16, 1], offset: ?>> to memref<8x16xi32, strided<[16, 1], offset: ?>> + air.dma_memcpy_nd (%alloc[] [] [], %transpose[] [] []) : (memref<1x1x8x16xi32, 1>, memref<8x16xi32, strided<[16, 1], offset: ?>>) + %alloc_3 = memref.alloc() : memref<1x1x16x16xi32, 1> + %transpose_4 = memref.transpose %subview_1 (d0, d1) -> (d0, d1) : memref<16x16xi32, strided<[32, 1], offset: ?>> to memref<16x16xi32, strided<[32, 1], offset: ?>> + air.dma_memcpy_nd (%alloc_3[] [] [], %transpose_4[] [] []) : (memref<1x1x16x16xi32, 1>, memref<16x16xi32, strided<[32, 1], offset: ?>>) + %alloc_5 = memref.alloc() : memref<1x1x8x16xi32, 1> + air.herd @herd_0 tile (%arg12, %arg13) in (%arg14=%c1_0, %arg15=%c1_0) args(%arg16=%alloc, %arg17=%alloc_3, %arg18=%alloc_5) : memref<1x1x8x16xi32, 1>, memref<1x1x16x16xi32, 1>, memref<1x1x8x16xi32, 1> { + %c0_i32 = arith.constant 0 : i32 + %subview_8 = memref.subview %arg16[%arg12, 0, 0, 0] [1, 1, 8, 16] [1, 1, 1, 1] : memref<1x1x8x16xi32, 1> to memref<1x1x8x16xi32, strided<[128, 128, 16, 1], offset: ?>, 1> + %subview_9 = memref.subview %arg17[0, %arg13, 0, 0] [1, 1, 16, 16] [1, 1, 1, 1] : memref<1x1x16x16xi32, 1> to memref<1x1x16x16xi32, strided<[256, 256, 16, 1], offset: ?>, 1> + %subview_10 = memref.subview %arg18[%arg12, %arg13, 0, 0] [1, 1, 8, 16] [1, 1, 1, 1] : memref<1x1x8x16xi32, 1> to memref<1x1x8x16xi32, strided<[128, 128, 16, 1], offset: ?>, 1> + %alloc_11 = memref.alloc() : memref<1x1x2x2x4x8xi32, 2> + %expand_shape = memref.expand_shape %subview_8 [[0], [1], [2, 3], [4, 5]] : memref<1x1x8x16xi32, strided<[128, 128, 16, 1], offset: ?>, 1> into memref<1x1x2x4x2x8xi32, strided<[128, 128, 64, 16, 8, 1], offset: ?>, 1> + %transpose_12 = memref.transpose %expand_shape (d0, d1, d2, d3, d4, d5) -> (d0, d1, d4, d2, d3, d5) : memref<1x1x2x4x2x8xi32, strided<[128, 128, 64, 16, 8, 1], offset: ?>, 1> to memref<1x1x2x2x4x8xi32, strided<[128, 128, 8, 64, 16, 1], offset: ?>, 1> + air.dma_memcpy_nd (%alloc_11[] [] [], %transpose_12[] [] []) : (memref<1x1x2x2x4x8xi32, 2>, memref<1x1x2x2x4x8xi32, strided<[128, 128, 8, 64, 16, 1], offset: ?>, 1>) + %alloc_13 = memref.alloc() : memref<1x1x2x2x8x8xi32, 2> + %expand_shape_14 = memref.expand_shape %subview_9 [[0], [1], [2, 3], [4, 5]] : memref<1x1x16x16xi32, strided<[256, 256, 16, 1], offset: ?>, 1> into memref<1x1x2x8x2x8xi32, strided<[256, 256, 128, 16, 8, 1], offset: ?>, 1> + %transpose_15 = memref.transpose %expand_shape_14 (d0, d1, d2, d3, d4, d5) -> (d0, d1, d4, d2, d3, d5) : memref<1x1x2x8x2x8xi32, strided<[256, 256, 128, 16, 8, 1], offset: ?>, 1> to memref<1x1x2x2x8x8xi32, strided<[256, 256, 8, 128, 16, 1], offset: ?>, 1> + air.dma_memcpy_nd (%alloc_13[] [] [], %transpose_15[] [] []) : (memref<1x1x2x2x8x8xi32, 2>, memref<1x1x2x2x8x8xi32, strided<[256, 256, 8, 128, 16, 1], offset: ?>, 1>) + %alloc_16 = memref.alloc() : memref<1x1x2x2x4x8xi32, 2> + %transpose_17 = memref.transpose %alloc_16 (d0, d1, d2, d3, d4, d5) -> (d0, d1, d3, d4, d2, d5) : memref<1x1x2x2x4x8xi32, 2> to memref<1x1x2x4x2x8xi32, strided<[128, 128, 32, 8, 64, 1]>, 2> + air.dma_memcpy_nd (%subview_10[] [] [], %transpose_17[] [] []) : (memref<1x1x8x16xi32, strided<[128, 128, 16, 1], offset: ?>, 1>, memref<1x1x2x4x2x8xi32, strided<[128, 128, 32, 8, 64, 1]>, 2>) + memref.dealloc %alloc_11 : memref<1x1x2x2x4x8xi32, 2> + memref.dealloc %alloc_13 : memref<1x1x2x2x8x8xi32, 2> + memref.dealloc %alloc_16 : memref<1x1x2x2x4x8xi32, 2> + air.herd_terminator + } + %subview_6 = memref.subview %alloc_5[0, 0, 0, 0] [1, 1, 8, 16] [1, 1, 1, 1] : memref<1x1x8x16xi32, 1> to memref<8x16xi32, 1> + %transpose_7 = memref.transpose %subview_6 (d0, d1) -> (d0, d1) : memref<8x16xi32, 1> to memref<8x16xi32, strided<[16, 1]>, 1> + air.dma_memcpy_nd (%subview_2[] [] [], %transpose_7[] [] []) : (memref<8x16xi32, strided<[32, 1], offset: ?>>, memref<8x16xi32, strided<[16, 1]>, 1>) + memref.dealloc %alloc_3 : memref<1x1x16x16xi32, 1> + memref.dealloc %alloc : memref<1x1x8x16xi32, 1> + memref.dealloc %alloc_5 : memref<1x1x8x16xi32, 1> + air.segment_terminator + } + air.launch_terminator + } + return + } +}