Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[ObjectFifo] Create a pass to convert temporary alloc to amdaie.buffer #783

Merged
merged 4 commits into from
Sep 18, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
// Copyright 2024 The IREE Authors
//
// Licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

#include "iree-amd-aie/IR/AMDAIEOps.h"
#include "iree-amd-aie/Transforms/Passes.h"

#define DEBUG_TYPE "iree-amdaie-temporary-alloc-bufferization"

namespace mlir::iree_compiler::AMDAIE {

namespace {

static std::optional<BufferOp> createBufferForTemporaryAllocOp(
IRRewriter &rewriter, WorkgroupOp workgroupOp, memref::AllocOp allocOp,
CoreOp coreOp, unsigned index) {
OpBuilder::InsertionGuard g(rewriter);
TileOp tileOp = coreOp.getTileOp();
// Reset rewriter's location to after last tile's declaration.
auto tiles = workgroupOp.getBody()->getOps<TileOp>();
assert(!tiles.empty() && "no tiles in workgroupOp");
rewriter.setInsertionPointAfter(*std::prev(tiles.end(), 1));
auto bufferType = cast<MemRefType>(allocOp.getType());
auto bufferOp = rewriter.create<AMDAIE::BufferOp>(
rewriter.getUnknownLoc(), bufferType, tileOp, nullptr);
return bufferOp;
}

static LogicalResult bufferizeTemporaryAllocInCoreOp(
IRRewriter &rewriter, WorkgroupOp workgroupOp, CoreOp coreOp,
SmallVector<Operation *> &toBeErased) {
// Step 1. Get all buffers within a CoreOp.
SmallVector<memref::AllocOp> allocOps;
coreOp.walk([&](Operation *op) {
if (auto allocOp = dyn_cast<memref::AllocOp>(op)) {
allocOps.push_back(allocOp);
toBeErased.push_back(allocOp);
} else if (auto deallocOp = dyn_cast<memref::DeallocOp>(op)) {
toBeErased.push_back(deallocOp);
}
});
// Bail out early in case of no temporary buffers.
if (allocOps.size() == 0) return success();
// Step 2. Traverse unique allocOps and create an aie.buffer for them.
SmallVector<BufferOp> temporaryBuffers;
unsigned tempBufferIndex = 0;
for (memref::AllocOp allocOp : allocOps) {
std::optional<BufferOp> temporaryBuffer = createBufferForTemporaryAllocOp(
rewriter, workgroupOp, allocOp, coreOp, tempBufferIndex++);
if (!temporaryBuffer) {
return failure();
}
allocOp.replaceAllUsesWith(temporaryBuffer.value().getResult());
}
return success();
}

class AMDAIETemporaryAllocBufferizationPass
: public impl::AMDAIETemporaryAllocBufferizationBase<
AMDAIETemporaryAllocBufferizationPass> {
public:
void getDependentDialects(DialectRegistry &registry) const override {
registry.insert<AMDAIEDialect>();
}

void runOnOperation() override;
};

void AMDAIETemporaryAllocBufferizationPass::runOnOperation() {
Operation *parentOp = getOperation();
IRRewriter rewriter(&getContext());

SmallVector<Operation *> toBeErased;
WalkResult res = parentOp->walk([&](WorkgroupOp workgroupOp) {
for (CoreOp coreOp : workgroupOp.getOps<CoreOp>()) {
if (failed(bufferizeTemporaryAllocInCoreOp(rewriter, workgroupOp, coreOp,
toBeErased)))
return WalkResult::interrupt();
}
return WalkResult::advance();
});
if (res.wasInterrupted()) return signalPassFailure();

for (Operation *op : toBeErased) {
op->dropAllUses();
rewriter.eraseOp(op);
}
}

} // namespace

std::unique_ptr<Pass> createAMDAIETemporaryAllocBufferizationPass() {
return std::make_unique<AMDAIETemporaryAllocBufferizationPass>();
}

} // namespace mlir::iree_compiler::AMDAIE
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,7 @@ iree_cc_library(
"AMDAIERemoveMemorySpace.cpp"
"AMDAIESinkIntoCore.cpp"
"AMDAIESplitLogicalObjFifosForConnectionReuse.cpp"
"AMDAIETemporaryAllocBufferization.cpp"
"AMDAIETile.cpp"
"AMDAIETileAndFuse.cpp"
"AMDAIEUtils.cpp"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,7 @@ namespace mlir::iree_compiler::AMDAIE {
#define GEN_PASS_DEF_AMDAIEREMOVEMEMORYSPACE
#define GEN_PASS_DEF_AMDAIESINKINTOCORE
#define GEN_PASS_DEF_AMDAIESPLITLOGICALOBJFIFOSFORCONNECTIONREUSE
#define GEN_PASS_DEF_AMDAIETEMPORARYALLOCBUFFERIZATION
#define GEN_PASS_DEF_AMDAIETILE
#define GEN_PASS_DEF_AMDAIETILEANDFUSE
#define GEN_PASS_DEF_AMDAIEVECTORIZATION
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -233,6 +233,9 @@ std::unique_ptr<Pass> createAMDAIESinkIntoCorePass();
/// Create a pass to split logicalobjectfifos for connection reuse.
std::unique_ptr<Pass> createAMDAIESplitLogicalObjFifosForConnectionReusePass();

/// Create a pass to bufferize temporary alloc ops.
std::unique_ptr<Pass> createAMDAIETemporaryAllocBufferizationPass();

/// Create pass to tile TilingInterface operations.
std::unique_ptr<Pass> createAMDAIETilePass(AMDAIETileOptions options = {});

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -506,6 +506,12 @@ def AMDAIESplitLogicalObjFifosForConnectionReuse :
let constructor = "mlir::iree_compiler::AMDAIE::createAMDAIESplitLogicalObjFifosForConnectionReusePass()";
}

def AMDAIETemporaryAllocBufferization :
Pass<"iree-amdaie-temporary-alloc-bufferization", ""> {
let summary = "Bufferizes temporary alloc buffers into `amdaie.buffer` ops.";
let constructor = "mlir::iree_compiler::AMDAIE::createAMDAIETemporaryAllocBufferizationPass()";
}

def AMDAIETile :
InterfacePass<"iree-amdaie-tile", "mlir::FunctionOpInterface"> {
let summary = "Pass to tile TilingInterface operations.";
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@ iree_lit_test_suite(
"remove_memory_space.mlir"
"sink_into_core.mlir"
"split_logicalobjfifos_for_connection_reuse.mlir"
"temporary_alloc_bufferization.mlir"
"tile_and_fuse_using_scf_for.mlir"
"tile_and_fuse_matmul_using_scf_forall.mlir"
"tile_and_fuse_convolution_using_scf_forall.mlir"
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
// RUN: iree-opt --split-input-file --pass-pipeline="builtin.module(iree-amdaie-temporary-alloc-bufferization)" --verify-diagnostics %s | FileCheck %s

// CHECK-LABEL: @temp_buffer
// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
// CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index
// CHECK-DAG: %[[C3:.*]] = arith.constant 3 : index
// CHECK-DAG: %[[TILE_0_2:.*]] = amdaie.tile(%[[C0]], %[[C2]])
// CHECK-DAG: %[[TILE_0_3:.*]] = amdaie.tile(%[[C0]], %[[C3]])
// CHECK-DAG: %[[TILE_1_2:.*]] = amdaie.tile(%[[C1]], %[[C2]])
// CHECK-DAG: %[[BUFFER_1_2_0:.*]] = amdaie.buffer(%[[TILE_1_2]]) : memref<1024xf32, 2 : i32>
// CHECK-DAG: %[[BUFFER_1_2_1:.*]] = amdaie.buffer(%[[TILE_1_2]]) : memref<1024xf32, 2 : i32>
// CHECK-DAG: %[[BUFFER_0_3_0:.*]] = amdaie.buffer(%[[TILE_0_3]]) : memref<1024xf32, 2 : i32>
// CHECK-DAG: %[[BUFFER_0_3_1:.*]] = amdaie.buffer(%[[TILE_0_3]]) : memref<1024xf32, 2 : i32>
// CHECK-DAG: %[[BUFFER_0_2:.*]] = amdaie.buffer(%[[TILE_0_2]]) : memref<1024xf32, 2 : i32>
// CHECK: amdaie.core(%[[TILE_0_2]]
// CHECK-NOT: memref.alloc
// CHECK: %[[CAST:.*]] = memref.reinterpret_cast %[[BUFFER_0_2]]
// CHECK: linalg.fill ins(%{{.*}}) outs(%[[CAST]]
// CHECK-NOT: dealloc
// CHECK: amdaie.end
// CHECK: amdaie.core(%[[TILE_0_3]]
// CHECK-NOT: memref.alloc
// CHECK: %[[CAST:.*]] = memref.reinterpret_cast %[[BUFFER_0_3_1]]
// CHECK: linalg.fill ins(%{{.*}}) outs(%[[CAST]]
// CHECK-NOT: dealloc
// CHECK-NOT: memref.alloc
// CHECK: %[[CAST_1:.*]] = memref.reinterpret_cast %[[BUFFER_0_3_0]]
// CHECK: linalg.fill ins(%{{.*}}) outs(%[[CAST_1]]
// CHECK-NOT: dealloc
// CHECK: amdaie.end
// CHECK: amdaie.core(%[[TILE_1_2]]
// CHECK-NOT: memref.alloc
// CHECK-NOT: memref.alloc
// CHECK: %[[CAST:.*]] = memref.reinterpret_cast %[[BUFFER_1_2_1]]
// CHECK: %[[CAST_1:.*]] = memref.reinterpret_cast %[[BUFFER_1_2_0]]
// CHECK: linalg.fill ins(%{{.*}}) outs(%[[CAST]]
// CHECK: linalg.fill ins(%{{.*}}) outs(%[[CAST_1]]
// CHECK-NOT: dealloc
// CHECK-NOT: dealloc
// CHECK: amdaie.end
func.func @temp_buffer() {
amdaie.workgroup {
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
%c2 = arith.constant 2 : index
%c3 = arith.constant 3 : index
%tile_0_2 = amdaie.tile(%c0, %c2)
%tile_0_3 = amdaie.tile(%c0, %c3)
%tile_1_2 = amdaie.tile(%c1, %c2)
%core_0_2 = amdaie.core(%tile_0_2, in : [], out : []) {
%cst_0 = arith.constant 0.000000e+00 : f32
%alloc = memref.alloc() : memref<1024xf32, 2 : i32>
%reinterpret_cast = memref.reinterpret_cast %alloc to offset: [0], sizes: [1, 1, 8, 8, 4, 4], strides: [1024, 1024, 128, 16, 4, 1] : memref<1024xf32, 2 : i32> to memref<1x1x8x8x4x4xf32, 2 : i32>
linalg.fill ins(%cst_0 : f32) outs(%reinterpret_cast : memref<1x1x8x8x4x4xf32, 2 : i32>)
memref.dealloc %alloc : memref<1024xf32, 2 : i32>
amdaie.end
}
%core_0_3 = amdaie.core(%tile_0_3, in : [], out : []) {
%cst_0 = arith.constant 0.000000e+00 : f32
%alloc = memref.alloc() : memref<1024xf32, 2 : i32>
%reinterpret_cast = memref.reinterpret_cast %alloc to offset: [0], sizes: [1, 1, 8, 8, 4, 4], strides: [1024, 1024, 128, 16, 4, 1] : memref<1024xf32, 2 : i32> to memref<1x1x8x8x4x4xf32, 2 : i32>
linalg.fill ins(%cst_0 : f32) outs(%reinterpret_cast : memref<1x1x8x8x4x4xf32, 2 : i32>)
memref.dealloc %alloc : memref<1024xf32, 2 : i32>
%alloc_1 = memref.alloc() : memref<1024xf32, 2 : i32>
%reinterpret_cast_1 = memref.reinterpret_cast %alloc_1 to offset: [0], sizes: [1, 1, 8, 8, 4, 4], strides: [1024, 1024, 128, 16, 4, 1] : memref<1024xf32, 2 : i32> to memref<1x1x8x8x4x4xf32, 2 : i32>
linalg.fill ins(%cst_0 : f32) outs(%reinterpret_cast_1 : memref<1x1x8x8x4x4xf32, 2 : i32>)
memref.dealloc %alloc_1 : memref<1024xf32, 2 : i32>
amdaie.end
}
%core_1_2 = amdaie.core(%tile_1_2, in : [], out : []) {
%cst_0 = arith.constant 0.000000e+00 : f32
%alloc = memref.alloc() : memref<1024xf32, 2 : i32>
%alloc_1 = memref.alloc() : memref<1024xf32, 2 : i32>
%reinterpret_cast = memref.reinterpret_cast %alloc to offset: [0], sizes: [1, 1, 8, 8, 4, 4], strides: [1024, 1024, 128, 16, 4, 1] : memref<1024xf32, 2 : i32> to memref<1x1x8x8x4x4xf32, 2 : i32>
%reinterpret_cast_1 = memref.reinterpret_cast %alloc_1 to offset: [0], sizes: [1, 1, 8, 8, 4, 4], strides: [1024, 1024, 128, 16, 4, 1] : memref<1024xf32, 2 : i32> to memref<1x1x8x8x4x4xf32, 2 : i32>
linalg.fill ins(%cst_0 : f32) outs(%reinterpret_cast : memref<1x1x8x8x4x4xf32, 2 : i32>)
linalg.fill ins(%cst_0 : f32) outs(%reinterpret_cast_1 : memref<1x1x8x8x4x4xf32, 2 : i32>)
memref.dealloc %alloc : memref<1024xf32, 2 : i32>
memref.dealloc %alloc_1 : memref<1024xf32, 2 : i32>
amdaie.end
}
amdaie.controlcode {
amdaie.end
}
}
return
}
Loading