Skip to content

Commit

Permalink
[LowerToAIE] Don't force change memspace (#692)
Browse files Browse the repository at this point in the history
This constraint the current convolution pipeline, where there are
further memref.subviews which end up with different memspaces on src and
dst.

Removing this constraint doesn't seem to make anything fail, so I'm just
removing it.
  • Loading branch information
newling authored Aug 26, 2024
1 parent 652d648 commit cb2114c
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 21 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -152,16 +152,14 @@ FailureOr<AIE::ObjectFifoCreateOp> createObjectFifo(
1, std::multiplies<>());
int64_t targetSize = std::accumulate(targetShape.begin(), targetShape.end(),
1, std::multiplies<>());
// TODO(jornt) for now, memory space 1 is used for objectfifos. Maybe refactor
// `aie.objectfifo` in the future to support different memory spaces.
MemRefType memrefType =
sourceSize < targetSize
? MemRefType::get({sourceSize}, srcType.getElementType(),
MemRefLayoutAttrInterface{},
rewriter.getI64IntegerAttr(1))
srcType.getMemorySpace())
: MemRefType::get({targetSize}, dstType.getElementType(),
MemRefLayoutAttrInterface{},
rewriter.getI64IntegerAttr(1));
dstType.getMemorySpace());
AIE::AIEObjectFifoType dtype = AIE::AIEObjectFifoType::get(memrefType);
auto fifo = rewriter.create<AIE::ObjectFifoCreateOp>(
rewriter.getUnknownLoc(), symName, srcTile, dstTiles,
Expand Down Expand Up @@ -204,10 +202,9 @@ LogicalResult accessOpToAIE(IRRewriter &rewriter,
}

auto type = cast<MemRefType>(oldReinterpretOp.getResult().getType());
// TODO(jornt): for now, memory space 1 is used for objectFifos. Refactor
// `aie.objectfifo` to support different memory spaces to avoid hardcoding.
MemRefType newType =
MemRefType::Builder(type).setMemorySpace(rewriter.getI64IntegerAttr(1));

MemRefType newType = MemRefType::Builder(type);

llvm::ArrayRef<int64_t> sizes = newType.getShape();
auto [strides, baseOffset] = getStridesAndOffset(newType);
auto reinterpretOp = rewriter.create<memref::ReinterpretCastOp>(
Expand All @@ -229,6 +226,7 @@ LogicalResult acquireOpToAIE(IRRewriter &rewriter,
IRMapping &mapper,
SmallVector<Operation *> &toBeErased) {
LLVM_DEBUG(llvm::dbgs() << "Convert [AMDAIE::LogicalObjectFifoAcquire]\n");

OpBuilder::InsertionGuard guard(rewriter);
rewriter.setInsertionPoint(acquireOp);
auto dmaOp =
Expand All @@ -244,20 +242,25 @@ LogicalResult acquireOpToAIE(IRRewriter &rewriter,
return acquireOp.emitError()
<< "input isn't mapped to an `aie.objectifo` operation";
}
AIE::AIEObjectFifoType ofTy =
cast<AIE::AIEObjectFifoType>(objFifo.getElemType());
MemRefType elementType = MemRefType::Builder(ofTy.getElementType())
.setMemorySpace(rewriter.getI64IntegerAttr(1));

auto acquireOpType = dyn_cast<LogicalObjectFifoType>(acquireOp.getType());
assert(acquireOpType &&
"Expected LogicalObjectFifoAcquire to have type "
"LogicalObjectFifoType");
MemRefType elementType = acquireOpType.getElementType();

auto subviewType = AIE::AIEObjectFifoSubviewType::get(elementType);
AIE::ObjectFifoPort port =
acquireOp.getPort() == LogicalObjectFifoPort::Produce
? AIE::ObjectFifoPort::Produce
: AIE::ObjectFifoPort::Consume;
auto objFifoAquireOp = rewriter.create<AIE::ObjectFifoAcquireOp>(
rewriter.getUnknownLoc(), subviewType, port, objFifo.getName(), 1);

auto subviewOp = rewriter.create<AIE::ObjectFifoSubviewAccessOp>(
rewriter.getUnknownLoc(), elementType, objFifoAquireOp.getSubview(),
rewriter.getIntegerAttr(rewriter.getI32Type(), 0));

// Map acquire op to new acquire + subview op.
mapper.map(acquireOp.getOperation(), subviewOp.getOperation());
mapper.map(acquireOp.getResult(), subviewOp.getOutput());
Expand Down Expand Up @@ -1008,7 +1011,7 @@ class AMDAIELowerToAIEPass
}

AMDAIELowerToAIEPass() = default;
AMDAIELowerToAIEPass(const AMDAIELowerToAIEPass &pass) {};
AMDAIELowerToAIEPass(const AMDAIELowerToAIEPass &pass){};
void runOnOperation() override;
};

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -220,8 +220,8 @@ module attributes {hal.executable.target = #executable_target_amdaie_xclbin_fb}
// -----

// CHECK: aie.device
// CHECK-DAG: func.func private @ukernel_A(memref<i32, 1>, index) attributes {llvm.bareptr = true}
// CHECK-DAG: func.func private @ukernel_B(memref<i32, 1>, index, memref<f32, 1>, index) attributes {llvm.bareptr = true}
// CHECK-DAG: func.func private @ukernel_A(memref<i32, 2>, index) attributes {llvm.bareptr = true}
// CHECK-DAG: func.func private @ukernel_B(memref<i32, 2>, index, memref<f32, 2>, index) attributes {llvm.bareptr = true}
// CHECK-DAG: %[[TILE_0_2:.+]] = aie.tile(0, 2)
// CHECK: aie.core(%[[TILE_0_2]])
// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
Expand All @@ -233,11 +233,11 @@ module attributes {hal.executable.target = #executable_target_amdaie_xclbin_fb}
// CHECK-SAME: Produce
// CHECK: %[[ACCESS0:.+]] = aie.objectfifo.subview.access %[[ACQUIRE0]]
// CHECK: %[[REINTERPRET0:.+]] = memref.reinterpret_cast %[[ACCESS0]]
// CHECK: linalg.fill ins(%{{.+}} : i32) outs(%[[REINTERPRET]] : memref<32x32xi32, 1>)
// CHECK: linalg.fill ins(%{{.+}} : i32) outs(%[[REINTERPRET]] : memref<32x32xi32, 2>)
// CHECK: %[[BASE_BUFFER:.*]], %{{.+}}, %{{.+}}:2, %{{.+}}:2 = memref.extract_strided_metadata %[[REINTERPRET]] :
// CHECK: %[[BASE_BUFFER0:.*]], %{{.+}}, %{{.+}}:2, %{{.+}}:2 = memref.extract_strided_metadata %[[REINTERPRET0]] :
// CHECK: func.call @ukernel_A(%[[BASE_BUFFER]], %[[C0]]) : (memref<i32, 1>, index) -> ()
// CHECK: func.call @ukernel_B(%[[BASE_BUFFER]], %[[C0]], %[[BASE_BUFFER0]], %[[C0]]) : (memref<i32, 1>, index, memref<f32, 1>, index) -> ()
// CHECK: func.call @ukernel_A(%[[BASE_BUFFER]], %[[C0]]) : (memref<i32, 2>, index) -> ()
// CHECK: func.call @ukernel_B(%[[BASE_BUFFER]], %[[C0]], %[[BASE_BUFFER0]], %[[C0]]) : (memref<i32, 2>, index, memref<f32, 2>, index) -> ()
// CHECK: aie.end
// CHECK: } {link_with = "/path/to/ukernel.o"}
// CHECK: aiex.runtime_sequence @lower_to_aie_ukernel
Expand Down Expand Up @@ -738,10 +738,10 @@ module attributes {hal.executable.target = #executable_target_amdaie_xclbin_fb}
// CHECK: aie.device(npu1_4col) {
// CHECK: %[[TILE_0_0:.*]] = aie.tile(0, 0)
// CHECK: %[[TILE_0_1:.*]] = aie.tile(0, 1)
// CHECK: aie.objectfifo @[[OBJ0:.*]](%[[TILE_0_0]], {%[[TILE_0_1]]}, 2 : i32) : !aie.objectfifo<memref<1024xbf16, 1>>
// CHECK: aie.objectfifo @[[OBJ1:.*]](%[[TILE_0_0]], {%[[TILE_0_1]]}, 2 : i32) : !aie.objectfifo<memref<1024xbf16, 1>>
// CHECK: aie.objectfifo @[[OBJ0:.*]](%[[TILE_0_0]], {%[[TILE_0_1]]}, 2 : i32) : !aie.objectfifo<memref<1024xbf16, 1 : i32>>
// CHECK: aie.objectfifo @[[OBJ1:.*]](%[[TILE_0_0]], {%[[TILE_0_1]]}, 2 : i32) : !aie.objectfifo<memref<1024xbf16, 1 : i32>>
// CHECK: aie.objectfifo @[[OBJ2:.*]](%[[TILE_0_1]]
// CHECK-SAME: %[[TILE_0_0]]}, 2 : i32) : !aie.objectfifo<memref<1024xf32, 1>>
// CHECK-SAME: %[[TILE_0_0]]}, 2 : i32) : !aie.objectfifo<memref<1024xf32>>
// CHECK: aiex.runtime_sequence @bf16_f32_lit_test
// CHECK-SAME: (%[[LHS:.*]]: memref<32x32xbf16>, %[[RHS:.*]]: memref<32x32xbf16>, %[[OUT:.*]]: memref<32x32xf32>) {
// CHECK: aiex.npu.dma_memcpy_nd
Expand Down

0 comments on commit cb2114c

Please sign in to comment.