Skip to content

Commit

Permalink
add test cases
Browse files Browse the repository at this point in the history
  • Loading branch information
Yu-Zhewen committed Dec 10, 2024
1 parent c9616f4 commit be58fb4
Show file tree
Hide file tree
Showing 10 changed files with 266 additions and 107 deletions.
15 changes: 6 additions & 9 deletions compiler/plugins/target/AMD-AIE/iree-amd-aie/IR/AMDAIEOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1101,16 +1101,15 @@ void NpuHalfDmaCpyNdOp::build(OpBuilder &b, OperationState &result,
Value input, ArrayRef<OpFoldResult> offsets,
ArrayRef<OpFoldResult> sizes,
ArrayRef<OpFoldResult> strides, Value bdId,
Value channel, BoolAttr useNextBd, Value nextBd,
Value startBd) {
Value channel, Value nextBd, Value startBd) {
SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);
dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes);
dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides);
build(b, result, resultTypes, connection, input, dynamicOffsets, dynamicSizes,
dynamicStrides, staticOffsets, staticSizes, staticStrides, bdId,
channel, useNextBd, nextBd, startBd);
channel, nextBd, startBd);
}

// Build a NpuHalfDmaCpyNdOp with static entries.
Expand All @@ -1119,8 +1118,7 @@ void NpuHalfDmaCpyNdOp::build(OpBuilder &b, OperationState &result,
Value input, ArrayRef<int64_t> offsets,
ArrayRef<int64_t> sizes,
ArrayRef<int64_t> strides, mlir::Value bdId,
Value channel, BoolAttr useNextBd, Value nextBd,
Value startBd) {
Value channel, Value nextBd, Value startBd) {
SmallVector<OpFoldResult> offsetValues = llvm::to_vector<4>(llvm::map_range(
offsets,
[&](int64_t v) -> OpFoldResult { return b.getI64IntegerAttr(v); }));
Expand All @@ -1132,24 +1130,23 @@ void NpuHalfDmaCpyNdOp::build(OpBuilder &b, OperationState &result,
strides,
[&](int64_t v) -> OpFoldResult { return b.getI64IntegerAttr(v); }));
build(b, result, resultTypes, connection, input, offsetValues, sizeValues,
strideValues, bdId, channel, useNextBd, nextBd, startBd);
strideValues, bdId, channel, nextBd, startBd);
}

// Build a NpuHalfDmaCpyNdOp with dynamic entries.
// Every offset, size and stride is supplied as an SSA `Value`. Each one is
// wrapped as an `OpFoldResult` and the call is forwarded to a `build` overload
// that takes `OpFoldResult` lists, passing all other arguments through as-is.
void NpuHalfDmaCpyNdOp::build(OpBuilder &b, OperationState &result,
TypeRange resultTypes, Value connection,
Value input, ValueRange offsets, ValueRange sizes,
ValueRange strides, mlir::Value bdId,
Value channel, BoolAttr useNextBd, Value nextBd,
Value startBd) {
Value channel, Value nextBd, Value startBd) {
// Wrap the dynamic values so that they can be passed as `OpFoldResult`s.
SmallVector<OpFoldResult> offsetValues = llvm::to_vector<4>(
llvm::map_range(offsets, [](Value v) -> OpFoldResult { return v; }));
SmallVector<OpFoldResult> sizeValues = llvm::to_vector<4>(
llvm::map_range(sizes, [](Value v) -> OpFoldResult { return v; }));
SmallVector<OpFoldResult> strideValues = llvm::to_vector<4>(
llvm::map_range(strides, [](Value v) -> OpFoldResult { return v; }));
// Delegate to the overload operating on `OpFoldResult` lists.
build(b, result, resultTypes, connection, input, offsetValues, sizeValues,
strideValues, bdId, channel, useNextBd, nextBd, startBd);
strideValues, bdId, channel, nextBd, startBd);
}

std::optional<int64_t> NpuHalfDmaCpyNdOp::getStaticBaseOffset() {
Expand Down
25 changes: 11 additions & 14 deletions compiler/plugins/target/AMD-AIE/iree-amd-aie/IR/AMDAIEOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -591,17 +591,15 @@ def AMDAIE_NpuHalfDmaCpyNdOp
ShapedType::kDynamic encodes that the corresponding entry has a dynamic
value.

It also supports the representation of DMA BD chaining using the `use_next_bd`,
`next_bd`, and `start_bd` operands. The `use_next_bd` operand indicates
whether another DMA operation is chained to follow this one.
If `use_next_bd` is `true`, the `next_bd` operand specifies the BD ID of
the next DMA operation in the chain.
It also supports the representation of DMA BD chaining using the
`next_bd` and `start_bd` operands. The `next_bd` operand specifies
the BD ID of the next DMA operation in the chain, if there is any.

The `start_bd` operand specifies the BD ID of the first DMA operation in a sequence.
- If `start_bd` is the same as `bd_id`, it marks the start of a chain.
- If `start_bd` differs from `bd_id` and `use_next_bd` is `true`, it represents
- If `start_bd` differs from `bd_id` and `next_bd` is set, it represents
an intermediate operation in the chain.
- If `start_bd` differs from `bd_id` and `use_next_bd` is `false`, it represents
- If `start_bd` differs from `bd_id` and `next_bd` is not set, it represents
the end of the chain.

Example:
Expand All @@ -610,14 +608,15 @@ def AMDAIE_NpuHalfDmaCpyNdOp
%2 = amdaie.connection(%1, %0)
: (!amdaie.logicalobjectfifo<memref<32x64xi32, 1>>,
!amdaie.logicalobjectfifo<memref<32x1024xi32>>)
%bd_id = amdaie.bd_id(%tile_0_0, 0)
%bd_id_0 = amdaie.bd_id(%tile_0_0, 0)
%bd_id_1 = amdaie.bd_id(%tile_0_0, 1)
%channel = amdaie.channel(%tile_0_0, 0, port_type = DMA, direction = MM2S)
...
amdaie.controlcode {
%5 = amdaie.logicalobjectfifo.from_memref %0, {%tile_0_0}
: memref<32x1024xi32> -> !amdaie.logicalobjectfifo<memref<32768xi32>>
%4 = amdaie.npu.half_dma_cpy_nd async %2(%0[0, 0] [32, 64] [1024, 1]
bd_id = %bd_id channel = %channel use_next_bd = false start_bd = %bd_id)
bd_id = %bd_id_0 channel = %channel next_bd = %bd_id_1 start_bd = %bd_id_0)
...
}
```
Expand All @@ -634,7 +633,6 @@ def AMDAIE_NpuHalfDmaCpyNdOp
DenseI64ArrayAttr:$static_strides,
Optional<Index>:$bd_id,
Optional<Index>:$channel,
OptionalAttr<BoolAttr>:$use_next_bd,
Optional<Index>:$next_bd,
Optional<Index>:$start_bd
);
Expand All @@ -651,7 +649,6 @@ def AMDAIE_NpuHalfDmaCpyNdOp
custom<DynamicIndexList>($strides, $static_strides)
(`bd_id` `=` $bd_id^)?
(`channel` `=` $channel^)?
(`use_next_bd` `=` $use_next_bd^)?
(`next_bd` `=` $next_bd^)?
(`start_bd` `=` $start_bd^)?
`)`
Expand All @@ -665,18 +662,18 @@ def AMDAIE_NpuHalfDmaCpyNdOp
"::mlir::Value":$input, "ArrayRef<OpFoldResult>":$offsets,
"ArrayRef<OpFoldResult>":$sizes, "ArrayRef<OpFoldResult>":$strides,
"::mlir::Value":$bd_id, "::mlir::Value":$channel,
"::mlir::BoolAttr":$use_next_bd, "::mlir::Value":$next_bd, "::mlir::Value":$start_bd)>,
"::mlir::Value":$next_bd, "::mlir::Value":$start_bd)>,
// Build a NpuHalfDmaCpyNdOp with static entries.
OpBuilder<(ins "::mlir::TypeRange":$result_types, "Value":$connection,
"::mlir::Value":$target, "ArrayRef<int64_t>":$offsets,
"ArrayRef<int64_t>":$sizes, "ArrayRef<int64_t>":$strides,
"::mlir::Value":$bd_id, "::mlir::Value":$channel,
"::mlir::BoolAttr":$use_next_bd, "::mlir::Value":$next_bd, "::mlir::Value":$start_bd)>,
"::mlir::Value":$next_bd, "::mlir::Value":$start_bd)>,
// Build a NpuHalfDmaCpyNdOp with dynamic entries.
OpBuilder<(ins "::mlir::TypeRange":$result_types, "Value":$connection,
"::mlir::Value":$input, "ValueRange":$offsets, "ValueRange":$sizes,
"ValueRange":$strides, "::mlir::Value":$bd_id, "::mlir::Value":$channel,
"::mlir::BoolAttr":$use_next_bd, "::mlir::Value":$next_bd, "::mlir::Value":$start_bd)>,
"::mlir::Value":$next_bd, "::mlir::Value":$start_bd)>,
];

let extraClassDeclaration = [{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -416,10 +416,10 @@ func.func @npu_half_dma_cpy_nd(%arg0: !amdaie.logicalobjectfifo<memref<2048xi32>
amdaie.npu.half_dma_cpy_nd %0(%arg0[0] [1024] [1] bd_id = %bd_id) : !amdaie.logicalobjectfifo<memref<2048xi32>>
// CHECK: amdaie.npu.half_dma_cpy_nd %[[CONNECTION_0]](%[[ARG0]] [%[[C0]], 0] [%[[C0]], 64] [%[[C0]], 1] channel = %[[CHANNEL]]) : !amdaie.logicalobjectfifo<memref<2048xi32>>
amdaie.npu.half_dma_cpy_nd %0(%arg0[%c0, 0] [%c0, 64] [%c0, 1] channel = %channel) : !amdaie.logicalobjectfifo<memref<2048xi32>>
// CHECK: amdaie.npu.half_dma_cpy_nd %[[CONNECTION_0]](%[[ARG0]] [] [] [] bd_id = %[[BD_ID]] channel = %[[CHANNEL]] use_next_bd = false) : !amdaie.logicalobjectfifo<memref<2048xi32>>
amdaie.npu.half_dma_cpy_nd %0(%arg0[] [] [] bd_id = %bd_id channel = %channel use_next_bd = false) : !amdaie.logicalobjectfifo<memref<2048xi32>>
// CHECK: amdaie.npu.half_dma_cpy_nd %[[CONNECTION_0]](%[[ARG0]] [] [] [] bd_id = %[[BD_ID]] channel = %[[CHANNEL]] use_next_bd = true next_bd = %[[BD_ID_1]] start_bd = %[[BD_ID]]) : !amdaie.logicalobjectfifo<memref<2048xi32>>
amdaie.npu.half_dma_cpy_nd %0(%arg0[] [] [] bd_id = %bd_id channel = %channel use_next_bd = true next_bd = %bd_id_1 start_bd = %bd_id) : !amdaie.logicalobjectfifo<memref<2048xi32>>
// CHECK: amdaie.npu.half_dma_cpy_nd %[[CONNECTION_0]](%[[ARG0]] [] [] [] bd_id = %[[BD_ID]] channel = %[[CHANNEL]]) : !amdaie.logicalobjectfifo<memref<2048xi32>>
amdaie.npu.half_dma_cpy_nd %0(%arg0[] [] [] bd_id = %bd_id channel = %channel) : !amdaie.logicalobjectfifo<memref<2048xi32>>
// CHECK: amdaie.npu.half_dma_cpy_nd %[[CONNECTION_0]](%[[ARG0]] [] [] [] bd_id = %[[BD_ID]] channel = %[[CHANNEL]] next_bd = %[[BD_ID_1]] start_bd = %[[BD_ID]]) : !amdaie.logicalobjectfifo<memref<2048xi32>>
amdaie.npu.half_dma_cpy_nd %0(%arg0[] [] [] bd_id = %bd_id channel = %channel next_bd = %bd_id_1 start_bd = %bd_id) : !amdaie.logicalobjectfifo<memref<2048xi32>>
return
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -109,14 +109,12 @@ struct HalfDmaCpyNdToNpuConverter final
staticStrides.insert(staticStrides.begin(),
numIntraAddrDim - staticStrides.size(), 0);

bool useNextBd = op.getUseNextBd().value_or(false);
bool useNextBd = false;
int32_t nextBd{0};
if (useNextBd) {
std::optional<AMDAIE::BdIdOp> nextBdIdOp = op.getNextBdIdOp();
if (!nextBdIdOp) {
return op.emitOpError() << "useNextBd set, but no next BD ID op found";
}
std::optional<AMDAIE::BdIdOp> nextBdIdOp = op.getNextBdIdOp();
if (nextBdIdOp) {
nextBd = getConstantIndexOrAssert(nextBdIdOp.value().getValue());
useNextBd = true;
}

bool validBd{true};
Expand Down Expand Up @@ -216,9 +214,9 @@ struct HalfDmaCpyNdToNpuConverter final
if (failed(npuPushToQueueOp)) return failure();
rewriter.replaceOp(op, *npuPushToQueueOp);

bool useNextBd = op.getUseNextBd().value_or(false);
if (useNextBd) {
// `useNextBd` is true, so either at the beginning or middle of a chain.
std::optional<AMDAIE::BdIdOp> nextBdIdOp = op.getNextBdIdOp();
if (nextBdIdOp) {
// `next_bd` is set, so either at the beginning or middle of a chain.
// No need to push to the queue, just erase the op.
rewriter.eraseOp(*npuPushToQueueOp);
} else {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -142,8 +142,7 @@ LogicalResult foldDmaWaits(const AMDAIE::AMDAIEDeviceModel &deviceModel,
rewriter.create<AMDAIE::NpuHalfDmaCpyNdOp>(
op.getLoc(), resultTypeRange, op.getConnection(), op.getInput(),
op.getMixedOffsets(), op.getMixedSizes(), op.getMixedStrides(),
op.getBdId(), op.getChannel(), op.getUseNextBdAttr(),
op.getNextBd(), op.getStartBd());
op.getBdId(), op.getChannel(), op.getNextBd(), op.getStartBd());
rewriter.eraseOp(op);
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,43 +19,50 @@ namespace {

using TileConnect = std::pair<AMDAIE::TileOp, AMDAIE::ConnectionOp>;

/// Utility function to update `use_next_bd`, `next_bd` and `start_bd` operands.
void updateChainOperands(IRRewriter &rewriter,
SmallVector<AMDAIE::NpuHalfDmaCpyNdOp> &dmaChain) {
if (dmaChain.size() < 2) return;
/// Utility function to update `next_bd` and `start_bd` operands.
///
/// Every op in `dmaChain` is re-created with `start_bd` set to the BD ID of
/// the first op in the chain. For all ops except the last one, `next_bd` is the
/// BD ID of the op that follows it in `dmaChain` and the new op is created
/// without result types; the last op keeps its result types and gets a null
/// `next_bd`. Returns failure, after emitting an error on the offending op, if
/// an op's parent operation differs from the parent of the first op.
LogicalResult updateChainOperands(
IRRewriter &rewriter, SmallVector<AMDAIE::NpuHalfDmaCpyNdOp> &dmaChain) {
// Nothing to do if the DMA chain length is one or less.
if (dmaChain.size() < 2) return success();

// Chain the DMA ops.
// The BD ID and the parent of the first op are the reference for the whole
// chain.
Value startBdId = dmaChain[0].getBdId();
Operation *parentOp = dmaChain[0]->getParentOp();
// Chain the DMA ops.
for (unsigned i = 0; i < dmaChain.size() - 1; ++i) {
AMDAIE::NpuHalfDmaCpyNdOp currDmaOp = dmaChain[i];
Value nextBd = dmaChain[i + 1].getBdId();
BoolAttr useNextBd = rewriter.getBoolAttr(true);
// All ops of a chain are required to share the parent of the first op.
if (currDmaOp->getParentOp() != parentOp) {
return currDmaOp.emitError(
"DMA operations to be chained must belong to the same scope");
}
Value nextBdId = dmaChain[i + 1].getBdId();
// No token is produced at the beginning or middle of a chain.
TypeRange token = TypeRange{};
rewriter.setInsertionPointAfter(currDmaOp);
rewriter.create<AMDAIE::NpuHalfDmaCpyNdOp>(
currDmaOp.getLoc(), token, currDmaOp.getConnection(),
currDmaOp.getInput(), currDmaOp.getMixedOffsets(),
currDmaOp.getMixedSizes(), currDmaOp.getMixedStrides(),
currDmaOp.getBdId(), currDmaOp.getChannel(), useNextBd, nextBd,
startBdId);
currDmaOp.getBdId(), currDmaOp.getChannel(), nextBdId, startBdId);
// The replacement op has no results, so the users of the original op are
// erased before the original op itself.
// NOTE(review): erasing users while iterating over the use list may
// invalidate the iteration - confirm this is safe.
for (auto &use : currDmaOp->getUses()) {
rewriter.eraseOp(use.getOwner());
}
rewriter.eraseOp(currDmaOp);
}
// Last DMA op in the chain.
AMDAIE::NpuHalfDmaCpyNdOp lastDmaOp = dmaChain.back();
Value nextBd = nullptr;
BoolAttr useNextBd = rewriter.getBoolAttr(false);
if (lastDmaOp->getParentOp() != parentOp) {
return lastDmaOp.emitError(
"DMA operations to be chained must belong to the same scope");
}
// A null `next_bd` is passed for the last op: no op follows it in the chain.
Value nextBdId = nullptr;
rewriter.setInsertionPointAfter(lastDmaOp);
auto lastDmaOpChained = rewriter.create<AMDAIE::NpuHalfDmaCpyNdOp>(
lastDmaOp.getLoc(), lastDmaOp.getResultTypes(), lastDmaOp.getConnection(),
lastDmaOp.getInput(), lastDmaOp.getMixedOffsets(),
lastDmaOp.getMixedSizes(), lastDmaOp.getMixedStrides(),
lastDmaOp.getBdId(), lastDmaOp.getChannel(), useNextBd, nextBd,
startBdId);
lastDmaOp.getBdId(), lastDmaOp.getChannel(), nextBdId, startBdId);
// The result types are preserved, so existing uses are redirected to the new
// op.
rewriter.replaceOp(lastDmaOp, lastDmaOpChained.getResults());
return success();
}

/// Utility function to determine if chains can grow further
Expand Down Expand Up @@ -95,25 +102,27 @@ void canChainGrowFurther(
}
}

/// Traverse the control code in reverse order to create DMA BD chains.
/// Traverse the control code in reverse order to create DMA BD chains. Reverse
/// traversal simplifies handling duplicate BD IDs, preventing the need to
/// revisit and modify earlier operations after processing later ones.
LogicalResult insertDmaBdChain(const AMDAIE::AMDAIEDeviceModel &deviceModel,
AMDAIE::ControlCodeOp controlCodeOp) {
IRRewriter rewriter(controlCodeOp->getContext());

// Move all BdIdOps to the beginning of the control code.
// This is to avoid dominance issues when chaining BD IDs.
SmallVector<Operation *> ops;
SmallVector<Operation *> bdIdOps;
WalkResult res = controlCodeOp->walk([&](Operation *op) {
if (auto bdIdOp = dyn_cast<AMDAIE::BdIdOp>(op)) {
ops.push_back(op);
bdIdOps.push_back(op);
}
return WalkResult::advance();
});
for (Operation *op : llvm::reverse(ops)) {
for (Operation *op : llvm::reverse(bdIdOps)) {
op->moveBefore(&controlCodeOp.front());
}

// BD ID that are have been assigned in each tile.
// BD IDs that have been assigned in each tile.
DenseMap<TileConnect, SmallVector<uint32_t>> tileConnectToBdIds;
// Buffers the DMA ops that will be chained.
DenseMap<TileConnect, SmallVector<AMDAIE::NpuHalfDmaCpyNdOp>>
Expand All @@ -122,7 +131,7 @@ LogicalResult insertDmaBdChain(const AMDAIE::AMDAIEDeviceModel &deviceModel,
res = controlCodeOp->walk<WalkOrder::PostOrder, ReverseIterator>(
[&](Operation *op) {
if (auto npuHalfDmaCpyNdOp = dyn_cast<AMDAIE::NpuHalfDmaCpyNdOp>(op)) {
// Not shim, will be earsed at ControlcodeLowering, ignore.
// Not shim, will be erased at ControlcodeLowering, ignore.
if (npuHalfDmaCpyNdOp.getMemorySpaceAsUInt() != 0) {
return WalkResult::advance();
}
Expand Down Expand Up @@ -181,9 +190,9 @@ LogicalResult insertDmaBdChain(const AMDAIE::AMDAIEDeviceModel &deviceModel,
return WalkResult::interrupt();
}

// Any duplicate BD ID from the same tile indicates the chain cannot
// grow further and requires breaking to release the conflicting BD
// ID.
// Any duplicate BD ID from the same tile indicates that the chain
// cannot grow further and requires breaking to release the
// conflicting BD ID.
SmallVector<TileConnect> chainsToBreak;
TileConnect currTileConnect = {tileOp, connectionOp};
canChainGrowFurther(bdId, currTileConnect, tileConnectToBdIds,
Expand All @@ -193,7 +202,9 @@ LogicalResult insertDmaBdChain(const AMDAIE::AMDAIEDeviceModel &deviceModel,
// the `updateChainOperands` function.
if (!chainsToBreak.empty()) {
for (auto &entry : chainsToBreak) {
updateChainOperands(rewriter, tileConnectToDmaChain[entry]);
// Propagate the failure out of the walk callback. Merely constructing the
// interrupt result without returning it would discard the failure and let
// the walk continue.
if (failed(updateChainOperands(rewriter,
tileConnectToDmaChain[entry])))
return WalkResult::interrupt();
tileConnectToBdIds[entry].clear();
tileConnectToDmaChain[entry].clear();
}
Expand All @@ -211,7 +222,8 @@ LogicalResult insertDmaBdChain(const AMDAIE::AMDAIEDeviceModel &deviceModel,

// Build the remaining chains.
for (auto &[entry, _] : tileConnectToBdIds) {
updateChainOperands(rewriter, tileConnectToDmaChain[entry]);
if (failed(updateChainOperands(rewriter, tileConnectToDmaChain[entry])))
return failure();
}

if (res.wasInterrupted()) return failure();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,6 @@ struct NpuDmaToHalfDmaCpyNdConverter final
return dmaOp.emitOpError()
<< "should operate on an `amdaie.connection` op";
}
BoolAttr useNextBd = rewriter.getBoolAttr(false);
Value nextBd{nullptr};
// Convert source half.
Value source =
Expand All @@ -53,7 +52,7 @@ struct NpuDmaToHalfDmaCpyNdConverter final
dmaOp.getLoc(), sourceResultTypes, connectionOp, source,
dmaOp.getSourceMixedOffsets(), dmaOp.getSourceMixedSizes(),
dmaOp.getSourceMixedStrides(), dmaOp.getSourceBdId(), sourceChannelOp,
useNextBd, nextBd, dmaOp.getSourceBdId());
nextBd, dmaOp.getSourceBdId());

// Convert target half.
Value target =
Expand All @@ -72,7 +71,7 @@ struct NpuDmaToHalfDmaCpyNdConverter final
dmaOp.getLoc(), targetResultTypes, connectionOp, target,
dmaOp.getTargetMixedOffsets(), dmaOp.getTargetMixedSizes(),
dmaOp.getTargetMixedStrides(), dmaOp.getTargetBdId(), targetChannelOp,
useNextBd, nextBd, dmaOp.getTargetBdId());
nextBd, dmaOp.getTargetBdId());
if (dmaOp.getNumResults() == 1) {
if (sourceDma.getNumResults() == 1) {
rewriter.replaceUsesWithIf(
Expand Down
Loading

0 comments on commit be58fb4

Please sign in to comment.