Skip to content

Commit

Permalink
[ObjectFifo] Create a pass to convert temporary alloc to amdaie.buffer (
Browse files Browse the repository at this point in the history
#783)

-- This commit creates a pass, `iree-amdaie-temporary-alloc-bufferization`,
   that converts temporary alloc ops/buffers into `amdaie.buffer` ops.

Signed-off-by: Abhishek Varma <[email protected]>
  • Loading branch information
Abhishek-Varma authored Sep 18, 2024
1 parent e31b56a commit d27772e
Show file tree
Hide file tree
Showing 7 changed files with 198 additions and 0 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
// Copyright 2024 The IREE Authors
//
// Licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

#include "iree-amd-aie/IR/AMDAIEOps.h"
#include "iree-amd-aie/Transforms/Passes.h"

#define DEBUG_TYPE "iree-amdaie-temporary-alloc-bufferization"

namespace mlir::iree_compiler::AMDAIE {

namespace {

/// Creates the `amdaie.buffer` op that stands in for the temporary `allocOp`
/// found inside `coreOp`.
///
/// The buffer is attached to the tile returned by `coreOp.getTileOp()`, has
/// the same memref type as `allocOp`, carries the location of `allocOp`, and
/// is inserted right after the last `amdaie.tile` op found in the body of
/// `workgroupOp`. The rewriter's insertion point is restored on return.
///
/// `index` is not used by the current implementation; it is kept so that the
/// signature stays unchanged for callers.
///
/// Returns the newly created buffer op, or `std::nullopt` (after emitting an
/// error on `allocOp`) when the alloc is dynamically shaped and can therefore
/// not be turned into a buffer without losing its size operands.
static std::optional<BufferOp> createBufferForTemporaryAllocOp(
    IRRewriter &rewriter, WorkgroupOp workgroupOp, memref::AllocOp allocOp,
    CoreOp coreOp, unsigned index) {
  OpBuilder::InsertionGuard g(rewriter);
  auto bufferType = cast<MemRefType>(allocOp.getType());
  // The replacement op takes no size operands, so replacing a dynamically
  // shaped alloc would silently drop them. Report it instead.
  if (!bufferType.hasStaticShape()) {
    allocOp.emitOpError(
        "has a dynamic shape and cannot be converted to an amdaie.buffer");
    return std::nullopt;
  }
  TileOp tileOp = coreOp.getTileOp();
  // Reset rewriter's location to after last tile's declaration.
  auto tiles = workgroupOp.getBody()->getOps<TileOp>();
  assert(!tiles.empty() && "no tiles in workgroupOp");
  rewriter.setInsertionPointAfter(*std::prev(tiles.end(), 1));
  // Reuse the alloc's location so diagnostics on the buffer still point at
  // the op it was derived from.
  auto bufferOp = rewriter.create<AMDAIE::BufferOp>(allocOp.getLoc(),
                                                    bufferType, tileOp, nullptr);
  return bufferOp;
}

/// Replaces every `memref.alloc` nested in `coreOp` with a buffer op created
/// by `createBufferForTemporaryAllocOp`.
///
/// All uses of each alloc are redirected to the result of its new buffer op.
/// Nothing is erased here: every visited `memref.alloc` and every
/// `memref.dealloc` nested in `coreOp` (whatever its operand) is appended to
/// `toBeErased`, and the caller is expected to erase them afterwards.
///
/// Returns `failure()` as soon as a buffer cannot be created for one of the
/// allocs, `success()` otherwise (including when the core has no allocs).
static LogicalResult bufferizeTemporaryAllocInCoreOp(
    IRRewriter &rewriter, WorkgroupOp workgroupOp, CoreOp coreOp,
    SmallVector<Operation *> &toBeErased) {
  // Step 1. Collect the allocs to convert, and schedule allocs and deallocs
  // for erasure in the order in which the walk visits them.
  SmallVector<memref::AllocOp> allocOps;
  coreOp.walk([&](Operation *op) {
    if (auto allocOp = dyn_cast<memref::AllocOp>(op)) {
      allocOps.push_back(allocOp);
      toBeErased.push_back(allocOp);
    } else if (isa<memref::DeallocOp>(op)) {
      toBeErased.push_back(op);
    }
  });
  // Bail out early in case of no temporary buffers.
  if (allocOps.empty()) return success();
  // Step 2. Create one buffer per alloc and redirect the alloc's uses to it.
  unsigned tempBufferIndex = 0;
  for (memref::AllocOp allocOp : allocOps) {
    std::optional<BufferOp> temporaryBuffer = createBufferForTemporaryAllocOp(
        rewriter, workgroupOp, allocOp, coreOp, tempBufferIndex++);
    if (!temporaryBuffer) return failure();
    allocOp.replaceAllUsesWith(temporaryBuffer->getResult());
  }
  return success();
}

/// Pass wrapper for the temporary-alloc bufferization. The actual rewrite is
/// implemented in `runOnOperation`, which is defined out of line.
struct AMDAIETemporaryAllocBufferizationPass final
    : impl::AMDAIETemporaryAllocBufferizationBase<
          AMDAIETemporaryAllocBufferizationPass> {
  void runOnOperation() override;

  /// Registers the AMDAIE dialect, whose ops this pass creates.
  void getDependentDialects(DialectRegistry &registry) const override {
    registry.insert<AMDAIEDialect>();
  }
};

void AMDAIETemporaryAllocBufferizationPass::runOnOperation() {
Operation *parentOp = getOperation();
IRRewriter rewriter(&getContext());

SmallVector<Operation *> toBeErased;
WalkResult res = parentOp->walk([&](WorkgroupOp workgroupOp) {
for (CoreOp coreOp : workgroupOp.getOps<CoreOp>()) {
if (failed(bufferizeTemporaryAllocInCoreOp(rewriter, workgroupOp, coreOp,
toBeErased)))
return WalkResult::interrupt();
}
return WalkResult::advance();
});
if (res.wasInterrupted()) return signalPassFailure();

for (Operation *op : toBeErased) {
op->dropAllUses();
rewriter.eraseOp(op);
}
}

} // namespace

/// Factory for the temporary-alloc bufferization pass; returns a freshly
/// constructed, default-configured instance owned by the caller.
std::unique_ptr<Pass> createAMDAIETemporaryAllocBufferizationPass() {
  auto pass = std::make_unique<AMDAIETemporaryAllocBufferizationPass>();
  return pass;
}

} // namespace mlir::iree_compiler::AMDAIE
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,7 @@ iree_cc_library(
"AMDAIERemoveMemorySpace.cpp"
"AMDAIESinkIntoCore.cpp"
"AMDAIESplitLogicalObjFifosForConnectionReuse.cpp"
"AMDAIETemporaryAllocBufferization.cpp"
"AMDAIETile.cpp"
"AMDAIETileAndFuse.cpp"
"AMDAIEUtils.cpp"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,7 @@ namespace mlir::iree_compiler::AMDAIE {
#define GEN_PASS_DEF_AMDAIEREMOVEMEMORYSPACE
#define GEN_PASS_DEF_AMDAIESINKINTOCORE
#define GEN_PASS_DEF_AMDAIESPLITLOGICALOBJFIFOSFORCONNECTIONREUSE
#define GEN_PASS_DEF_AMDAIETEMPORARYALLOCBUFFERIZATION
#define GEN_PASS_DEF_AMDAIETILE
#define GEN_PASS_DEF_AMDAIETILEANDFUSE
#define GEN_PASS_DEF_AMDAIEVECTORIZATION
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -233,6 +233,9 @@ std::unique_ptr<Pass> createAMDAIESinkIntoCorePass();
/// Create a pass to split logicalobjectfifos for connection reuse.
std::unique_ptr<Pass> createAMDAIESplitLogicalObjFifosForConnectionReusePass();

/// Create a pass to bufferize temporary alloc ops, i.e. to replace them with
/// `amdaie.buffer` ops.
std::unique_ptr<Pass> createAMDAIETemporaryAllocBufferizationPass();

/// Create pass to tile TilingInterface operations.
std::unique_ptr<Pass> createAMDAIETilePass(AMDAIETileOptions options = {});

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -506,6 +506,12 @@ def AMDAIESplitLogicalObjFifosForConnectionReuse :
let constructor = "mlir::iree_compiler::AMDAIE::createAMDAIESplitLogicalObjFifosForConnectionReusePass()";
}

// Pass that rewrites temporary alloc buffers into `amdaie.buffer` ops.
// No anchor operation name is given (the second template argument is empty).
def AMDAIETemporaryAllocBufferization :
Pass<"iree-amdaie-temporary-alloc-bufferization", ""> {
let summary = "Bufferizes temporary alloc buffers into `amdaie.buffer` ops.";
let constructor = "mlir::iree_compiler::AMDAIE::createAMDAIETemporaryAllocBufferizationPass()";
}

def AMDAIETile :
InterfacePass<"iree-amdaie-tile", "mlir::FunctionOpInterface"> {
let summary = "Pass to tile TilingInterface operations.";
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@ iree_lit_test_suite(
"remove_memory_space.mlir"
"sink_into_core.mlir"
"split_logicalobjfifos_for_connection_reuse.mlir"
"temporary_alloc_bufferization.mlir"
"tile_and_fuse_using_scf_for.mlir"
"tile_and_fuse_matmul_using_scf_forall.mlir"
"tile_and_fuse_convolution_using_scf_forall.mlir"
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
// RUN: iree-opt --split-input-file --pass-pipeline="builtin.module(iree-amdaie-temporary-alloc-bufferization)" --verify-diagnostics %s | FileCheck %s

// CHECK-LABEL: @temp_buffer
// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
// CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index
// CHECK-DAG: %[[C3:.*]] = arith.constant 3 : index
// CHECK-DAG: %[[TILE_0_2:.*]] = amdaie.tile(%[[C0]], %[[C2]])
// CHECK-DAG: %[[TILE_0_3:.*]] = amdaie.tile(%[[C0]], %[[C3]])
// CHECK-DAG: %[[TILE_1_2:.*]] = amdaie.tile(%[[C1]], %[[C2]])
// CHECK-DAG: %[[BUFFER_1_2_0:.*]] = amdaie.buffer(%[[TILE_1_2]]) : memref<1024xf32, 2 : i32>
// CHECK-DAG: %[[BUFFER_1_2_1:.*]] = amdaie.buffer(%[[TILE_1_2]]) : memref<1024xf32, 2 : i32>
// CHECK-DAG: %[[BUFFER_0_3_0:.*]] = amdaie.buffer(%[[TILE_0_3]]) : memref<1024xf32, 2 : i32>
// CHECK-DAG: %[[BUFFER_0_3_1:.*]] = amdaie.buffer(%[[TILE_0_3]]) : memref<1024xf32, 2 : i32>
// CHECK-DAG: %[[BUFFER_0_2:.*]] = amdaie.buffer(%[[TILE_0_2]]) : memref<1024xf32, 2 : i32>
// CHECK: amdaie.core(%[[TILE_0_2]]
// CHECK-NOT: memref.alloc
// CHECK: %[[CAST:.*]] = memref.reinterpret_cast %[[BUFFER_0_2]]
// CHECK: linalg.fill ins(%{{.*}}) outs(%[[CAST]]
// CHECK-NOT: dealloc
// CHECK: amdaie.end
// CHECK: amdaie.core(%[[TILE_0_3]]
// CHECK-NOT: memref.alloc
// CHECK: %[[CAST:.*]] = memref.reinterpret_cast %[[BUFFER_0_3_1]]
// CHECK: linalg.fill ins(%{{.*}}) outs(%[[CAST]]
// CHECK-NOT: dealloc
// CHECK-NOT: memref.alloc
// CHECK: %[[CAST_1:.*]] = memref.reinterpret_cast %[[BUFFER_0_3_0]]
// CHECK: linalg.fill ins(%{{.*}}) outs(%[[CAST_1]]
// CHECK-NOT: dealloc
// CHECK: amdaie.end
// CHECK: amdaie.core(%[[TILE_1_2]]
// CHECK-NOT: memref.alloc
// CHECK-NOT: memref.alloc
// CHECK: %[[CAST:.*]] = memref.reinterpret_cast %[[BUFFER_1_2_1]]
// CHECK: %[[CAST_1:.*]] = memref.reinterpret_cast %[[BUFFER_1_2_0]]
// CHECK: linalg.fill ins(%{{.*}}) outs(%[[CAST]]
// CHECK: linalg.fill ins(%{{.*}}) outs(%[[CAST_1]]
// CHECK-NOT: dealloc
// CHECK-NOT: dealloc
// CHECK: amdaie.end
// Test input: one workgroup with three tiles and one core per tile. Each core
// allocates, fills and deallocates temporary memrefs in three different
// arrangements (see the note in front of each core).
func.func @temp_buffer() {
amdaie.workgroup {
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
%c2 = arith.constant 2 : index
%c3 = arith.constant 3 : index
%tile_0_2 = amdaie.tile(%c0, %c2)
%tile_0_3 = amdaie.tile(%c0, %c3)
%tile_1_2 = amdaie.tile(%c1, %c2)
// Core on tile (0, 2): a single alloc/dealloc pair.
%core_0_2 = amdaie.core(%tile_0_2, in : [], out : []) {
%cst_0 = arith.constant 0.000000e+00 : f32
%alloc = memref.alloc() : memref<1024xf32, 2 : i32>
%reinterpret_cast = memref.reinterpret_cast %alloc to offset: [0], sizes: [1, 1, 8, 8, 4, 4], strides: [1024, 1024, 128, 16, 4, 1] : memref<1024xf32, 2 : i32> to memref<1x1x8x8x4x4xf32, 2 : i32>
linalg.fill ins(%cst_0 : f32) outs(%reinterpret_cast : memref<1x1x8x8x4x4xf32, 2 : i32>)
memref.dealloc %alloc : memref<1024xf32, 2 : i32>
amdaie.end
}
// Core on tile (0, 3): two alloc/dealloc pairs, one after the other.
%core_0_3 = amdaie.core(%tile_0_3, in : [], out : []) {
%cst_0 = arith.constant 0.000000e+00 : f32
%alloc = memref.alloc() : memref<1024xf32, 2 : i32>
%reinterpret_cast = memref.reinterpret_cast %alloc to offset: [0], sizes: [1, 1, 8, 8, 4, 4], strides: [1024, 1024, 128, 16, 4, 1] : memref<1024xf32, 2 : i32> to memref<1x1x8x8x4x4xf32, 2 : i32>
linalg.fill ins(%cst_0 : f32) outs(%reinterpret_cast : memref<1x1x8x8x4x4xf32, 2 : i32>)
memref.dealloc %alloc : memref<1024xf32, 2 : i32>
%alloc_1 = memref.alloc() : memref<1024xf32, 2 : i32>
%reinterpret_cast_1 = memref.reinterpret_cast %alloc_1 to offset: [0], sizes: [1, 1, 8, 8, 4, 4], strides: [1024, 1024, 128, 16, 4, 1] : memref<1024xf32, 2 : i32> to memref<1x1x8x8x4x4xf32, 2 : i32>
linalg.fill ins(%cst_0 : f32) outs(%reinterpret_cast_1 : memref<1x1x8x8x4x4xf32, 2 : i32>)
memref.dealloc %alloc_1 : memref<1024xf32, 2 : i32>
amdaie.end
}
// Core on tile (1, 2): two allocs that are both live at the same time; both
// are deallocated at the end of the core.
%core_1_2 = amdaie.core(%tile_1_2, in : [], out : []) {
%cst_0 = arith.constant 0.000000e+00 : f32
%alloc = memref.alloc() : memref<1024xf32, 2 : i32>
%alloc_1 = memref.alloc() : memref<1024xf32, 2 : i32>
%reinterpret_cast = memref.reinterpret_cast %alloc to offset: [0], sizes: [1, 1, 8, 8, 4, 4], strides: [1024, 1024, 128, 16, 4, 1] : memref<1024xf32, 2 : i32> to memref<1x1x8x8x4x4xf32, 2 : i32>
%reinterpret_cast_1 = memref.reinterpret_cast %alloc_1 to offset: [0], sizes: [1, 1, 8, 8, 4, 4], strides: [1024, 1024, 128, 16, 4, 1] : memref<1024xf32, 2 : i32> to memref<1x1x8x8x4x4xf32, 2 : i32>
linalg.fill ins(%cst_0 : f32) outs(%reinterpret_cast : memref<1x1x8x8x4x4xf32, 2 : i32>)
linalg.fill ins(%cst_0 : f32) outs(%reinterpret_cast_1 : memref<1x1x8x8x4x4xf32, 2 : i32>)
memref.dealloc %alloc : memref<1024xf32, 2 : i32>
memref.dealloc %alloc_1 : memref<1024xf32, 2 : i32>
amdaie.end
}
amdaie.controlcode {
amdaie.end
}
}
return
}

0 comments on commit d27772e

Please sign in to comment.