Skip to content

Commit

Permalink
Add a new pass to create DMA chains inside controlcode (#974)
Browse files Browse the repository at this point in the history
Adds support for creating chains on `half_dma_cpy_nd`. After further
lowering, the entire chain can share one `push_to_queue` and `dma_wait`
ops. Refactored from #931
  • Loading branch information
Yu-Zhewen authored Dec 11, 2024
1 parent db10c75 commit 362f041
Show file tree
Hide file tree
Showing 14 changed files with 686 additions and 14 deletions.
12 changes: 6 additions & 6 deletions compiler/plugins/target/AMD-AIE/iree-amd-aie/IR/AMDAIEOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1123,15 +1123,15 @@ void NpuHalfDmaCpyNdOp::build(OpBuilder &b, OperationState &result,
Value input, ArrayRef<OpFoldResult> offsets,
ArrayRef<OpFoldResult> sizes,
ArrayRef<OpFoldResult> strides, Value bdId,
Value channel) {
Value channel, Value nextBd, Value startBd) {
SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);
dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes);
dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides);
build(b, result, resultTypes, connection, input, dynamicOffsets, dynamicSizes,
dynamicStrides, staticOffsets, staticSizes, staticStrides, bdId,
channel);
channel, nextBd, startBd);
}

// Build a NpuHalfDmaCpyNdOp with static entries.
Expand All @@ -1140,7 +1140,7 @@ void NpuHalfDmaCpyNdOp::build(OpBuilder &b, OperationState &result,
Value input, ArrayRef<int64_t> offsets,
ArrayRef<int64_t> sizes,
ArrayRef<int64_t> strides, mlir::Value bdId,
Value channel) {
Value channel, Value nextBd, Value startBd) {
SmallVector<OpFoldResult> offsetValues = llvm::to_vector<4>(llvm::map_range(
offsets,
[&](int64_t v) -> OpFoldResult { return b.getI64IntegerAttr(v); }));
Expand All @@ -1152,23 +1152,23 @@ void NpuHalfDmaCpyNdOp::build(OpBuilder &b, OperationState &result,
strides,
[&](int64_t v) -> OpFoldResult { return b.getI64IntegerAttr(v); }));
build(b, result, resultTypes, connection, input, offsetValues, sizeValues,
strideValues, bdId, channel);
strideValues, bdId, channel, nextBd, startBd);
}

// Build a NpuHalfDmaCpyNdOp with dynamic entries.
void NpuHalfDmaCpyNdOp::build(OpBuilder &b, OperationState &result,
TypeRange resultTypes, Value connection,
Value input, ValueRange offsets, ValueRange sizes,
ValueRange strides, mlir::Value bdId,
Value channel) {
Value channel, Value nextBd, Value startBd) {
SmallVector<OpFoldResult> offsetValues = llvm::to_vector<4>(
llvm::map_range(offsets, [](Value v) -> OpFoldResult { return v; }));
SmallVector<OpFoldResult> sizeValues = llvm::to_vector<4>(
llvm::map_range(sizes, [](Value v) -> OpFoldResult { return v; }));
SmallVector<OpFoldResult> strideValues = llvm::to_vector<4>(
llvm::map_range(strides, [](Value v) -> OpFoldResult { return v; }));
build(b, result, resultTypes, connection, input, offsetValues, sizeValues,
strideValues, bdId, channel);
strideValues, bdId, channel, nextBd, startBd);
}

std::optional<int64_t> NpuHalfDmaCpyNdOp::getStaticBaseOffset() {
Expand Down
42 changes: 36 additions & 6 deletions compiler/plugins/target/AMD-AIE/iree-amd-aie/IR/AMDAIEOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -595,20 +595,32 @@ def AMDAIE_NpuHalfDmaCpyNdOp
ShapedType::kDynamic encodes that the corresponding entry has a dynamic
value.

It also supports the representation of DMA BD chaining using the,
`next_bd`, and `start_bd` operands. The `next_bd` operand specifies
the BD ID of the next DMA operation in the chain, if there is any.

The `start_bd` operand specifies the BD ID of the first DMA operation in a sequence.
- If `start_bd` is the same as `bd_id`, it marks the start of a chain.
- If `start_bd` differs from `bd_id` and `next_bd` is set, it represents
an intermediate operation in the chain.
- If `start_bd` differs from `bd_id` and `next_bd` is not set, it represents
the end of the chain.

Example:

```mlir
%2 = amdaie.connection(%1, %0)
: (!amdaie.logicalobjectfifo<memref<32x64xi32, 1>>,
!amdaie.logicalobjectfifo<memref<32x1024xi32>>)
%bd_id = amdaie.bd_id(%tile_0_0, 0)
%bd_id_0 = amdaie.bd_id(%tile_0_0, 0)
%bd_id_1 = amdaie.bd_id(%tile_0_0, 1)
%channel = amdaie.channel(%tile_0_0, 0, port_type = DMA, direction = MM2S)
...
amdaie.controlcode {
%5 = amdaie.logicalobjectfifo.from_memref %0, {%tile_0_0}
: memref<32x1024xi32> -> !amdaie.logicalobjectfifo<memref<32768xi32>>
%4 = amdaie.npu.half_dma_cpy_nd async %2(%0[0, 0] [32, 64] [1024, 1]
bd_id = %bd_id channel = %channel)
bd_id = %bd_id_0 channel = %channel next_bd = %bd_id_1 start_bd = %bd_id_0)
...
}
```
Expand All @@ -624,7 +636,9 @@ def AMDAIE_NpuHalfDmaCpyNdOp
DenseI64ArrayAttr:$static_sizes,
DenseI64ArrayAttr:$static_strides,
Optional<Index>:$bd_id,
Optional<Index>:$channel
Optional<Index>:$channel,
Optional<Index>:$next_bd,
Optional<Index>:$start_bd
);

let results = (outs Optional<AMDAIE_AsyncTokenType>:$async_token);
Expand All @@ -639,6 +653,8 @@ def AMDAIE_NpuHalfDmaCpyNdOp
custom<DynamicIndexList>($strides, $static_strides)
(`bd_id` `=` $bd_id^)?
(`channel` `=` $channel^)?
(`next_bd` `=` $next_bd^)?
(`start_bd` `=` $start_bd^)?
`)`
attr-dict
`:` type($input)
Expand All @@ -649,16 +665,19 @@ def AMDAIE_NpuHalfDmaCpyNdOp
OpBuilder<(ins "::mlir::TypeRange":$result_types, "Value":$connection,
"::mlir::Value":$input, "ArrayRef<OpFoldResult>":$offsets,
"ArrayRef<OpFoldResult>":$sizes, "ArrayRef<OpFoldResult>":$strides,
"::mlir::Value":$bd_id, "::mlir::Value":$channel)>,
"::mlir::Value":$bd_id, "::mlir::Value":$channel,
CArg<"::mlir::Value", "nullptr">:$next_bd, CArg<"::mlir::Value", "nullptr">:$start_bd)>,
// Build a NpuHalfDmaCpyNdOp with static entries.
OpBuilder<(ins "::mlir::TypeRange":$result_types, "Value":$connection,
"::mlir::Value":$target, "ArrayRef<int64_t>":$offsets,
"ArrayRef<int64_t>":$sizes, "ArrayRef<int64_t>":$strides,
"::mlir::Value":$bd_id, "::mlir::Value":$channel)>,
"::mlir::Value":$bd_id, "::mlir::Value":$channel,
CArg<"::mlir::Value", "nullptr">:$next_bd, CArg<"::mlir::Value", "nullptr">:$start_bd)>,
// Build a NpuHalfDmaCpyNdOp with dynamic entries.
OpBuilder<(ins "::mlir::TypeRange":$result_types, "Value":$connection,
"::mlir::Value":$input, "ValueRange":$offsets, "ValueRange":$sizes,
"ValueRange":$strides, "::mlir::Value":$bd_id, "::mlir::Value":$channel)>
"ValueRange":$strides, "::mlir::Value":$bd_id, "::mlir::Value":$channel,
CArg<"::mlir::Value", "nullptr">:$next_bd, CArg<"::mlir::Value", "nullptr">:$start_bd)>,
];

let extraClassDeclaration = [{
Expand All @@ -674,9 +693,20 @@ def AMDAIE_NpuHalfDmaCpyNdOp
}

std::optional<BdIdOp> getBdIdOp() {
if (!getBdId()) return std::nullopt;
return dyn_cast_if_present<BdIdOp>(getBdId().getDefiningOp());
}

std::optional<BdIdOp> getNextBdIdOp() {
if (!getNextBd()) return std::nullopt;
return dyn_cast_if_present<BdIdOp>(getNextBd().getDefiningOp());
}

std::optional<BdIdOp> getStartBdIdOp() {
if (!getStartBd()) return std::nullopt;
return dyn_cast_if_present<BdIdOp>(getStartBd().getDefiningOp());
}

// Return the input `amdaie.connection` operation.
std::optional<ConnectionOp> getConnectionOp() {
return dyn_cast_if_present<ConnectionOp>(getConnection().getDefiningOp());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -397,13 +397,15 @@ func.func @npu_dma_cpy_nd_all_operands(%arg0: !amdaie.logicalobjectfifo<memref<1
// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index
// CHECK-DAG: %[[TILE_0_0:.+]] = amdaie.tile(%[[C0]], %[[C0]])
// CHECK-DAG: %[[BD_ID:.+]] = amdaie.bd_id(%[[TILE_0_0]], %[[C0]])
// CHECK-DAG: %[[BD_ID_1:.+]] = amdaie.bd_id(%[[TILE_0_0]], %[[C1]])
// CHECK-DAG: %[[CHANNEL:.*]] = amdaie.channel(%[[TILE_0_0]], 0, port_type = DMA, direction = S2MM)
// CHECK-DAG: %[[CONNECTION_0:.+]] = amdaie.connection
func.func @npu_half_dma_cpy_nd(%arg0: !amdaie.logicalobjectfifo<memref<2048xi32>>, %arg1: !amdaie.logicalobjectfifo<memref<2048xi32, 1 : i32>>) {
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
%tile_0_0 = amdaie.tile(%c0, %c0)
%bd_id = amdaie.bd_id(%tile_0_0, %c0)
%bd_id_1 = amdaie.bd_id(%tile_0_0, %c1)
%channel = amdaie.channel(%tile_0_0, 0, port_type = DMA, direction = S2MM)
%0 = amdaie.connection(%arg0, %arg1) : (!amdaie.logicalobjectfifo<memref<2048xi32>>, !amdaie.logicalobjectfifo<memref<2048xi32, 1 : i32>>)
// CHECK: amdaie.npu.half_dma_cpy_nd %[[CONNECTION_0]](%[[ARG0]] [] [] []) : !amdaie.logicalobjectfifo<memref<2048xi32>>
Expand All @@ -416,6 +418,8 @@ func.func @npu_half_dma_cpy_nd(%arg0: !amdaie.logicalobjectfifo<memref<2048xi32>
amdaie.npu.half_dma_cpy_nd %0(%arg0[%c0, 0] [%c0, 64] [%c0, 1] channel = %channel) : !amdaie.logicalobjectfifo<memref<2048xi32>>
// CHECK: amdaie.npu.half_dma_cpy_nd %[[CONNECTION_0]](%[[ARG0]] [] [] [] bd_id = %[[BD_ID]] channel = %[[CHANNEL]]) : !amdaie.logicalobjectfifo<memref<2048xi32>>
amdaie.npu.half_dma_cpy_nd %0(%arg0[] [] [] bd_id = %bd_id channel = %channel) : !amdaie.logicalobjectfifo<memref<2048xi32>>
// CHECK: amdaie.npu.half_dma_cpy_nd %[[CONNECTION_0]](%[[ARG0]] [] [] [] bd_id = %[[BD_ID]] channel = %[[CHANNEL]] next_bd = %[[BD_ID_1]] start_bd = %[[BD_ID]]) : !amdaie.logicalobjectfifo<memref<2048xi32>>
amdaie.npu.half_dma_cpy_nd %0(%arg0[] [] [] bd_id = %bd_id channel = %channel next_bd = %bd_id_1 start_bd = %bd_id) : !amdaie.logicalobjectfifo<memref<2048xi32>>
return
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -110,8 +110,13 @@ struct HalfDmaCpyNdToNpuConverter final
staticStrides.insert(staticStrides.begin(),
numIntraAddrDim - staticStrides.size(), 0);

bool useNextBd{false};
bool useNextBd = false;
int32_t nextBd{0};
if (std::optional<AMDAIE::BdIdOp> nextBdIdOp = op.getNextBdIdOp()) {
nextBd = getConstantIndexOrAssert(nextBdIdOp.value().getValue());
useNextBd = true;
}

bool validBd{true};
int32_t lockRelVal{0};
int32_t lockRelId{0};
Expand Down Expand Up @@ -208,6 +213,21 @@ struct HalfDmaCpyNdToNpuConverter final
strides);
if (failed(npuPushToQueueOp)) return failure();
rewriter.replaceOp(op, *npuPushToQueueOp);

std::optional<AMDAIE::BdIdOp> nextBdIdOp = op.getNextBdIdOp();
if (nextBdIdOp) {
// `next_bd` is set, so either at the beginning or middle of a chain.
// No need to push to the queue, just erase the op.
rewriter.eraseOp(*npuPushToQueueOp);
} else {
std::optional<AMDAIE::BdIdOp> maybeStartBdIdOp = op.getStartBdIdOp();
if (maybeStartBdIdOp) {
// Update with the BD ID at the start of the chain.
AMDAIE::BdIdOp startBdIdOp = maybeStartBdIdOp.value();
uint32_t startBdId = getConstantIndexOrAssert(startBdIdOp.getValue());
npuPushToQueueOp->setBdId(startBdId);
}
}
return success();
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,7 @@ LogicalResult foldDmaWaits(const AMDAIE::AMDAIEDeviceModel &deviceModel,
rewriter.create<AMDAIE::NpuHalfDmaCpyNdOp>(
op.getLoc(), resultTypeRange, op.getConnection(), op.getInput(),
op.getMixedOffsets(), op.getMixedSizes(), op.getMixedStrides(),
op.getBdId(), op.getChannel());
op.getBdId(), op.getChannel(), op.getNextBd(), op.getStartBd());
rewriter.eraseOp(op);
}
}
Expand Down
Loading

0 comments on commit 362f041

Please sign in to comment.