-
Notifications
You must be signed in to change notification settings - Fork 30
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[ObjectFifo] Create a pass to convert temporary alloc to amdaie.buffer (#783)
This commit creates a pass `iree-amdaie-temporary-alloc-bufferization` to convert temporary alloc/buffers to amdaie.buffer ops. Signed-off-by: Abhishek Varma <[email protected]>
- Loading branch information
1 parent
e31b56a
commit d27772e
Showing
7 changed files
with
198 additions
and
0 deletions.
There are no files selected for viewing
98 changes: 98 additions & 0 deletions
98
...iler/plugins/target/AMD-AIE/iree-amd-aie/Transforms/AMDAIETemporaryAllocBufferization.cpp
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,98 @@ | ||
// Copyright 2024 The IREE Authors | ||
// | ||
// Licensed under the Apache License v2.0 with LLVM Exceptions. | ||
// See https://llvm.org/LICENSE.txt for license information. | ||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception | ||
|
||
#include "iree-amd-aie/IR/AMDAIEOps.h" | ||
#include "iree-amd-aie/Transforms/Passes.h" | ||
|
||
#define DEBUG_TYPE "iree-amdaie-temporary-alloc-bufferization" | ||
|
||
namespace mlir::iree_compiler::AMDAIE { | ||
|
||
namespace { | ||
|
||
/// Creates an `amdaie.buffer` op that stands in for the temporary `allocOp`
/// found inside `coreOp`.
///
/// The buffer takes the alloc's memref type, is attached to the tile returned
/// by `coreOp.getTileOp()`, and is inserted directly after the last
/// `amdaie.tile` op in the body of `workgroupOp`. The rewriter's insertion
/// point is restored when the function returns.
///
/// `index` is accepted from the caller but is currently not used here.
///
/// Returns the new buffer op, or `std::nullopt` (after emitting an error on
/// `workgroupOp`) when the workgroup body holds no tile op to anchor the
/// insertion point.
static std::optional<BufferOp> createBufferForTemporaryAllocOp(
    IRRewriter &rewriter, WorkgroupOp workgroupOp, memref::AllocOp allocOp,
    CoreOp coreOp, unsigned index) {
  OpBuilder::InsertionGuard guard(rewriter);
  TileOp tileOp = coreOp.getTileOp();
  // Locate the last tile declared directly in the workgroup body; the new
  // buffer is placed right after it.
  TileOp lastTileOp;
  for (TileOp candidate : workgroupOp.getBody()->getOps<TileOp>())
    lastTileOp = candidate;
  if (!lastTileOp) {
    // Report a diagnostic instead of asserting so that release builds fail
    // gracefully rather than stepping before the start of an empty range.
    workgroupOp.emitOpError("no tiles in workgroupOp");
    return std::nullopt;
  }
  rewriter.setInsertionPointAfter(lastTileOp);
  auto bufferType = cast<MemRefType>(allocOp.getType());
  // Reuse the alloc's location so diagnostics on the buffer still point at the
  // original allocation. The trailing nullptr is forwarded to the BufferOp
  // builder unchanged.
  // NOTE(review): presumably an optional attribute left unset; confirm against
  // the BufferOp definition.
  return rewriter.create<AMDAIE::BufferOp>(allocOp.getLoc(), bufferType,
                                           tileOp, nullptr);
}
|
||
/// Replaces the uses of every `memref.alloc` nested in `coreOp` with a newly
/// created `amdaie.buffer` op.
///
/// The alloc ops themselves are not erased here. Each alloc, and every
/// `memref.dealloc` nested in `coreOp` (regardless of which value it frees),
/// is appended to `toBeErased` so the caller can delete them afterwards.
///
/// Returns failure if a buffer could not be created for some alloc, success
/// otherwise (including when the core contains no alloc at all).
static LogicalResult bufferizeTemporaryAllocInCoreOp(
    IRRewriter &rewriter, WorkgroupOp workgroupOp, CoreOp coreOp,
    SmallVector<Operation *> &toBeErased) {
  // Step 1. Collect the allocs to convert and queue allocs/deallocs for
  // erasure, in walk order.
  SmallVector<memref::AllocOp> allocOps;
  coreOp.walk([&](Operation *op) {
    if (auto allocOp = dyn_cast<memref::AllocOp>(op)) {
      allocOps.push_back(allocOp);
      toBeErased.push_back(allocOp);
    } else if (auto deallocOp = dyn_cast<memref::DeallocOp>(op)) {
      toBeErased.push_back(deallocOp);
    }
  });
  // Bail out early in case of no temporary buffers.
  if (allocOps.empty()) return success();

  // Step 2. Create one amdaie.buffer per alloc and redirect the alloc's uses
  // to it.
  unsigned tempBufferIndex = 0;
  for (memref::AllocOp allocOp : allocOps) {
    std::optional<BufferOp> temporaryBuffer = createBufferForTemporaryAllocOp(
        rewriter, workgroupOp, allocOp, coreOp, tempBufferIndex++);
    if (!temporaryBuffer) return failure();
    allocOp.replaceAllUsesWith(temporaryBuffer->getResult());
  }
  return success();
}
|
||
class AMDAIETemporaryAllocBufferizationPass | ||
: public impl::AMDAIETemporaryAllocBufferizationBase< | ||
AMDAIETemporaryAllocBufferizationPass> { | ||
public: | ||
void getDependentDialects(DialectRegistry ®istry) const override { | ||
registry.insert<AMDAIEDialect>(); | ||
} | ||
|
||
void runOnOperation() override; | ||
}; | ||
|
||
void AMDAIETemporaryAllocBufferizationPass::runOnOperation() { | ||
Operation *parentOp = getOperation(); | ||
IRRewriter rewriter(&getContext()); | ||
|
||
SmallVector<Operation *> toBeErased; | ||
WalkResult res = parentOp->walk([&](WorkgroupOp workgroupOp) { | ||
for (CoreOp coreOp : workgroupOp.getOps<CoreOp>()) { | ||
if (failed(bufferizeTemporaryAllocInCoreOp(rewriter, workgroupOp, coreOp, | ||
toBeErased))) | ||
return WalkResult::interrupt(); | ||
} | ||
return WalkResult::advance(); | ||
}); | ||
if (res.wasInterrupted()) return signalPassFailure(); | ||
|
||
for (Operation *op : toBeErased) { | ||
op->dropAllUses(); | ||
rewriter.eraseOp(op); | ||
} | ||
} | ||
|
||
} // namespace | ||
|
||
std::unique_ptr<Pass> createAMDAIETemporaryAllocBufferizationPass() { | ||
return std::make_unique<AMDAIETemporaryAllocBufferizationPass>(); | ||
} | ||
|
||
} // namespace mlir::iree_compiler::AMDAIE |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
88 changes: 88 additions & 0 deletions
88
...er/plugins/target/AMD-AIE/iree-amd-aie/Transforms/test/temporary_alloc_bufferization.mlir
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,88 @@ | ||
// RUN: iree-opt --split-input-file --pass-pipeline="builtin.module(iree-amdaie-temporary-alloc-bufferization)" --verify-diagnostics %s | FileCheck %s | ||
|
||
// CHECK-LABEL: @temp_buffer | ||
// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index | ||
// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index | ||
// CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index | ||
// CHECK-DAG: %[[C3:.*]] = arith.constant 3 : index | ||
// CHECK-DAG: %[[TILE_0_2:.*]] = amdaie.tile(%[[C0]], %[[C2]]) | ||
// CHECK-DAG: %[[TILE_0_3:.*]] = amdaie.tile(%[[C0]], %[[C3]]) | ||
// CHECK-DAG: %[[TILE_1_2:.*]] = amdaie.tile(%[[C1]], %[[C2]]) | ||
// CHECK-DAG: %[[BUFFER_1_2_0:.*]] = amdaie.buffer(%[[TILE_1_2]]) : memref<1024xf32, 2 : i32> | ||
// CHECK-DAG: %[[BUFFER_1_2_1:.*]] = amdaie.buffer(%[[TILE_1_2]]) : memref<1024xf32, 2 : i32> | ||
// CHECK-DAG: %[[BUFFER_0_3_0:.*]] = amdaie.buffer(%[[TILE_0_3]]) : memref<1024xf32, 2 : i32> | ||
// CHECK-DAG: %[[BUFFER_0_3_1:.*]] = amdaie.buffer(%[[TILE_0_3]]) : memref<1024xf32, 2 : i32> | ||
// CHECK-DAG: %[[BUFFER_0_2:.*]] = amdaie.buffer(%[[TILE_0_2]]) : memref<1024xf32, 2 : i32> | ||
// CHECK: amdaie.core(%[[TILE_0_2]] | ||
// CHECK-NOT: memref.alloc | ||
// CHECK: %[[CAST:.*]] = memref.reinterpret_cast %[[BUFFER_0_2]] | ||
// CHECK: linalg.fill ins(%{{.*}}) outs(%[[CAST]] | ||
// CHECK-NOT: dealloc | ||
// CHECK: amdaie.end | ||
// CHECK: amdaie.core(%[[TILE_0_3]] | ||
// CHECK-NOT: memref.alloc | ||
// CHECK: %[[CAST:.*]] = memref.reinterpret_cast %[[BUFFER_0_3_1]] | ||
// CHECK: linalg.fill ins(%{{.*}}) outs(%[[CAST]] | ||
// CHECK-NOT: dealloc | ||
// CHECK-NOT: memref.alloc | ||
// CHECK: %[[CAST_1:.*]] = memref.reinterpret_cast %[[BUFFER_0_3_0]] | ||
// CHECK: linalg.fill ins(%{{.*}}) outs(%[[CAST_1]] | ||
// CHECK-NOT: dealloc | ||
// CHECK: amdaie.end | ||
// CHECK: amdaie.core(%[[TILE_1_2]] | ||
// CHECK-NOT: memref.alloc | ||
// CHECK-NOT: memref.alloc | ||
// CHECK: %[[CAST:.*]] = memref.reinterpret_cast %[[BUFFER_1_2_1]] | ||
// CHECK: %[[CAST_1:.*]] = memref.reinterpret_cast %[[BUFFER_1_2_0]] | ||
// CHECK: linalg.fill ins(%{{.*}}) outs(%[[CAST]] | ||
// CHECK: linalg.fill ins(%{{.*}}) outs(%[[CAST_1]] | ||
// CHECK-NOT: dealloc | ||
// CHECK-NOT: dealloc | ||
// CHECK: amdaie.end | ||
// Test input: one workgroup with three tiles and one core per tile.
//  - core on tile (0, 2): a single temporary alloc.
//  - core on tile (0, 3): two temporary allocs, the first deallocated before
//    the second is allocated.
//  - core on tile (1, 2): two temporary allocs that are both allocated before
//    either is deallocated.
// Every alloc has type memref<1024xf32, 2 : i32> and is used through a
// memref.reinterpret_cast feeding a linalg.fill.
func.func @temp_buffer() { | ||
amdaie.workgroup { | ||
%c0 = arith.constant 0 : index | ||
%c1 = arith.constant 1 : index | ||
%c2 = arith.constant 2 : index | ||
%c3 = arith.constant 3 : index | ||
%tile_0_2 = amdaie.tile(%c0, %c2) | ||
%tile_0_3 = amdaie.tile(%c0, %c3) | ||
%tile_1_2 = amdaie.tile(%c1, %c2) | ||
%core_0_2 = amdaie.core(%tile_0_2, in : [], out : []) { | ||
%cst_0 = arith.constant 0.000000e+00 : f32 | ||
%alloc = memref.alloc() : memref<1024xf32, 2 : i32> | ||
%reinterpret_cast = memref.reinterpret_cast %alloc to offset: [0], sizes: [1, 1, 8, 8, 4, 4], strides: [1024, 1024, 128, 16, 4, 1] : memref<1024xf32, 2 : i32> to memref<1x1x8x8x4x4xf32, 2 : i32> | ||
linalg.fill ins(%cst_0 : f32) outs(%reinterpret_cast : memref<1x1x8x8x4x4xf32, 2 : i32>) | ||
memref.dealloc %alloc : memref<1024xf32, 2 : i32> | ||
amdaie.end | ||
} | ||
%core_0_3 = amdaie.core(%tile_0_3, in : [], out : []) { | ||
%cst_0 = arith.constant 0.000000e+00 : f32 | ||
%alloc = memref.alloc() : memref<1024xf32, 2 : i32> | ||
%reinterpret_cast = memref.reinterpret_cast %alloc to offset: [0], sizes: [1, 1, 8, 8, 4, 4], strides: [1024, 1024, 128, 16, 4, 1] : memref<1024xf32, 2 : i32> to memref<1x1x8x8x4x4xf32, 2 : i32> | ||
linalg.fill ins(%cst_0 : f32) outs(%reinterpret_cast : memref<1x1x8x8x4x4xf32, 2 : i32>) | ||
memref.dealloc %alloc : memref<1024xf32, 2 : i32> | ||
%alloc_1 = memref.alloc() : memref<1024xf32, 2 : i32> | ||
%reinterpret_cast_1 = memref.reinterpret_cast %alloc_1 to offset: [0], sizes: [1, 1, 8, 8, 4, 4], strides: [1024, 1024, 128, 16, 4, 1] : memref<1024xf32, 2 : i32> to memref<1x1x8x8x4x4xf32, 2 : i32> | ||
linalg.fill ins(%cst_0 : f32) outs(%reinterpret_cast_1 : memref<1x1x8x8x4x4xf32, 2 : i32>) | ||
memref.dealloc %alloc_1 : memref<1024xf32, 2 : i32> | ||
amdaie.end | ||
} | ||
%core_1_2 = amdaie.core(%tile_1_2, in : [], out : []) { | ||
%cst_0 = arith.constant 0.000000e+00 : f32 | ||
%alloc = memref.alloc() : memref<1024xf32, 2 : i32> | ||
%alloc_1 = memref.alloc() : memref<1024xf32, 2 : i32> | ||
%reinterpret_cast = memref.reinterpret_cast %alloc to offset: [0], sizes: [1, 1, 8, 8, 4, 4], strides: [1024, 1024, 128, 16, 4, 1] : memref<1024xf32, 2 : i32> to memref<1x1x8x8x4x4xf32, 2 : i32> | ||
%reinterpret_cast_1 = memref.reinterpret_cast %alloc_1 to offset: [0], sizes: [1, 1, 8, 8, 4, 4], strides: [1024, 1024, 128, 16, 4, 1] : memref<1024xf32, 2 : i32> to memref<1x1x8x8x4x4xf32, 2 : i32> | ||
linalg.fill ins(%cst_0 : f32) outs(%reinterpret_cast : memref<1x1x8x8x4x4xf32, 2 : i32>) | ||
linalg.fill ins(%cst_0 : f32) outs(%reinterpret_cast_1 : memref<1x1x8x8x4x4xf32, 2 : i32>) | ||
memref.dealloc %alloc : memref<1024xf32, 2 : i32> | ||
memref.dealloc %alloc_1 : memref<1024xf32, 2 : i32> | ||
amdaie.end | ||
} | ||
amdaie.controlcode { | ||
amdaie.end | ||
} | ||
} | ||
return | ||
} |