Skip to content

Commit

Permalink
simplify code
Browse files Browse the repository at this point in the history
  • Loading branch information
Yu-Zhewen committed Dec 9, 2024
1 parent 2230b16 commit 82f87b3
Show file tree
Hide file tree
Showing 13 changed files with 309 additions and 398 deletions.
6 changes: 3 additions & 3 deletions compiler/plugins/target/AMD-AIE/iree-amd-aie/IR/AMDAIEOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1101,7 +1101,7 @@ void NpuHalfDmaCpyNdOp::build(OpBuilder &b, OperationState &result,
Value input, ArrayRef<OpFoldResult> offsets,
ArrayRef<OpFoldResult> sizes,
ArrayRef<OpFoldResult> strides, Value bdId,
Value channel, bool useNextBd, Value nextBd,
Value channel, BoolAttr useNextBd, Value nextBd,
Value startBd) {
SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
Expand All @@ -1119,7 +1119,7 @@ void NpuHalfDmaCpyNdOp::build(OpBuilder &b, OperationState &result,
Value input, ArrayRef<int64_t> offsets,
ArrayRef<int64_t> sizes,
ArrayRef<int64_t> strides, mlir::Value bdId,
Value channel, bool useNextBd, Value nextBd,
Value channel, BoolAttr useNextBd, Value nextBd,
Value startBd) {
SmallVector<OpFoldResult> offsetValues = llvm::to_vector<4>(llvm::map_range(
offsets,
Expand All @@ -1140,7 +1140,7 @@ void NpuHalfDmaCpyNdOp::build(OpBuilder &b, OperationState &result,
TypeRange resultTypes, Value connection,
Value input, ValueRange offsets, ValueRange sizes,
ValueRange strides, mlir::Value bdId,
Value channel, bool useNextBd, Value nextBd,
Value channel, BoolAttr useNextBd, Value nextBd,
Value startBd) {
SmallVector<OpFoldResult> offsetValues = llvm::to_vector<4>(
llvm::map_range(offsets, [](Value v) -> OpFoldResult { return v; }));
Expand Down
24 changes: 15 additions & 9 deletions compiler/plugins/target/AMD-AIE/iree-amd-aie/IR/AMDAIEOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -595,10 +595,15 @@ def AMDAIE_NpuHalfDmaCpyNdOp
`next_bd`, and `start_bd` operands. The `use_next_bd` operand indicates
whether another DMA operation is chained to follow this one.
If `use_next_bd` is `true`, the `next_bd` operand specifies the BD ID of
the next DMA operation in the chain. Within a chain, the `start_bd` operand
identifies the BD ID of the first DMA operation in the sequence.
When `use_next_bd` is `false`, the `start_bd` is set to the same value as `bd_id`.

the next DMA operation in the chain.

The `start_bd` operand specifies the BD ID of the first DMA operation in a sequence.
- If `start_bd` is the same as `bd_id`, it marks the start of a chain.
- If `start_bd` differs from `bd_id` and `use_next_bd` is `true`, it represents
an intermediate operation in the chain.
- If `start_bd` differs from `bd_id` and `use_next_bd` is `false`, it represents
the end of the chain.

Example:

```mlir
Expand Down Expand Up @@ -629,7 +634,7 @@ def AMDAIE_NpuHalfDmaCpyNdOp
DenseI64ArrayAttr:$static_strides,
Optional<Index>:$bd_id,
Optional<Index>:$channel,
BoolAttr:$use_next_bd,
OptionalAttr<BoolAttr>:$use_next_bd,
Optional<Index>:$next_bd,
Optional<Index>:$start_bd
);
Expand All @@ -646,7 +651,7 @@ def AMDAIE_NpuHalfDmaCpyNdOp
custom<DynamicIndexList>($strides, $static_strides)
(`bd_id` `=` $bd_id^)?
(`channel` `=` $channel^)?
`use_next_bd` `=` $use_next_bd
(`use_next_bd` `=` $use_next_bd^)?
(`next_bd` `=` $next_bd^)?
(`start_bd` `=` $start_bd^)?
`)`
Expand All @@ -660,18 +665,18 @@ def AMDAIE_NpuHalfDmaCpyNdOp
"::mlir::Value":$input, "ArrayRef<OpFoldResult>":$offsets,
"ArrayRef<OpFoldResult>":$sizes, "ArrayRef<OpFoldResult>":$strides,
"::mlir::Value":$bd_id, "::mlir::Value":$channel,
"bool":$use_next_bd, "::mlir::Value":$next_bd, "::mlir::Value":$start_bd)>,
"::mlir::BoolAttr":$use_next_bd, "::mlir::Value":$next_bd, "::mlir::Value":$start_bd)>,
// Build a NpuHalfDmaCpyNdOp with static entries.
OpBuilder<(ins "::mlir::TypeRange":$result_types, "Value":$connection,
"::mlir::Value":$target, "ArrayRef<int64_t>":$offsets,
"ArrayRef<int64_t>":$sizes, "ArrayRef<int64_t>":$strides,
"::mlir::Value":$bd_id, "::mlir::Value":$channel,
"bool":$use_next_bd, "::mlir::Value":$next_bd, "::mlir::Value":$start_bd)>,
"::mlir::BoolAttr":$use_next_bd, "::mlir::Value":$next_bd, "::mlir::Value":$start_bd)>,
// Build a NpuHalfDmaCpyNdOp with dynamic entries.
OpBuilder<(ins "::mlir::TypeRange":$result_types, "Value":$connection,
"::mlir::Value":$input, "ValueRange":$offsets, "ValueRange":$sizes,
"ValueRange":$strides, "::mlir::Value":$bd_id, "::mlir::Value":$channel,
"bool":$use_next_bd, "::mlir::Value":$next_bd, "::mlir::Value":$start_bd)>,
"::mlir::BoolAttr":$use_next_bd, "::mlir::Value":$next_bd, "::mlir::Value":$start_bd)>,
];

let extraClassDeclaration = [{
Expand All @@ -687,6 +692,7 @@ def AMDAIE_NpuHalfDmaCpyNdOp
}

std::optional<BdIdOp> getBdIdOp() {
if (!getBdId()) return std::nullopt;
return dyn_cast_if_present<BdIdOp>(getBdId().getDefiningOp());
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -397,25 +397,29 @@ func.func @npu_dma_cpy_nd_all_operands(%arg0: !amdaie.logicalobjectfifo<memref<1
// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index
// CHECK-DAG: %[[TILE_0_0:.+]] = amdaie.tile(%[[C0]], %[[C0]])
// CHECK-DAG: %[[BD_ID:.+]] = amdaie.bd_id(%[[TILE_0_0]], %[[C0]])
// CHECK-DAG: %[[BD_ID_1:.+]] = amdaie.bd_id(%[[TILE_0_0]], %[[C1]])
// CHECK-DAG: %[[CHANNEL:.*]] = amdaie.channel(%[[TILE_0_0]], 0, port_type = DMA, direction = S2MM)
// CHECK-DAG: %[[CONNECTION_0:.+]] = amdaie.connection
func.func @npu_half_dma_cpy_nd(%arg0: !amdaie.logicalobjectfifo<memref<2048xi32>>, %arg1: !amdaie.logicalobjectfifo<memref<2048xi32, 1 : i32>>) {
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
%tile_0_0 = amdaie.tile(%c0, %c0)
%bd_id = amdaie.bd_id(%tile_0_0, %c0)
%bd_id_1 = amdaie.bd_id(%tile_0_0, %c1)
%channel = amdaie.channel(%tile_0_0, 0, port_type = DMA, direction = S2MM)
%0 = amdaie.connection(%arg0, %arg1) : (!amdaie.logicalobjectfifo<memref<2048xi32>>, !amdaie.logicalobjectfifo<memref<2048xi32, 1 : i32>>)
// CHECK: amdaie.npu.half_dma_cpy_nd %[[CONNECTION_0]](%[[ARG0]] [] [] [] use_next_bd = false) : !amdaie.logicalobjectfifo<memref<2048xi32>>
amdaie.npu.half_dma_cpy_nd %0(%arg0[] [] [] use_next_bd = false) : !amdaie.logicalobjectfifo<memref<2048xi32>>
// CHECK: %{{.+}} = amdaie.npu.half_dma_cpy_nd async %[[CONNECTION_0]](%[[ARG0]] [] [] [] use_next_bd = false) : !amdaie.logicalobjectfifo<memref<2048xi32>>
amdaie.npu.half_dma_cpy_nd async %0(%arg0[] [] [] use_next_bd = false) : !amdaie.logicalobjectfifo<memref<2048xi32>>
// CHECK: amdaie.npu.half_dma_cpy_nd %[[CONNECTION_0]](%[[ARG0]] [0] [1024] [1] bd_id = %[[BD_ID]] use_next_bd = false) : !amdaie.logicalobjectfifo<memref<2048xi32>>
amdaie.npu.half_dma_cpy_nd %0(%arg0[0] [1024] [1] bd_id = %bd_id use_next_bd = false) : !amdaie.logicalobjectfifo<memref<2048xi32>>
// CHECK: amdaie.npu.half_dma_cpy_nd %[[CONNECTION_0]](%[[ARG0]] [%[[C0]], 0] [%[[C0]], 64] [%[[C0]], 1] channel = %[[CHANNEL]] use_next_bd = false) : !amdaie.logicalobjectfifo<memref<2048xi32>>
amdaie.npu.half_dma_cpy_nd %0(%arg0[%c0, 0] [%c0, 64] [%c0, 1] channel = %channel use_next_bd = false) : !amdaie.logicalobjectfifo<memref<2048xi32>>
// CHECK: amdaie.npu.half_dma_cpy_nd %[[CONNECTION_0]](%[[ARG0]] [] [] []) : !amdaie.logicalobjectfifo<memref<2048xi32>>
amdaie.npu.half_dma_cpy_nd %0(%arg0[] [] []) : !amdaie.logicalobjectfifo<memref<2048xi32>>
// CHECK: %{{.+}} = amdaie.npu.half_dma_cpy_nd async %[[CONNECTION_0]](%[[ARG0]] [] [] []) : !amdaie.logicalobjectfifo<memref<2048xi32>>
amdaie.npu.half_dma_cpy_nd async %0(%arg0[] [] []) : !amdaie.logicalobjectfifo<memref<2048xi32>>
// CHECK: amdaie.npu.half_dma_cpy_nd %[[CONNECTION_0]](%[[ARG0]] [0] [1024] [1] bd_id = %[[BD_ID]]) : !amdaie.logicalobjectfifo<memref<2048xi32>>
amdaie.npu.half_dma_cpy_nd %0(%arg0[0] [1024] [1] bd_id = %bd_id) : !amdaie.logicalobjectfifo<memref<2048xi32>>
// CHECK: amdaie.npu.half_dma_cpy_nd %[[CONNECTION_0]](%[[ARG0]] [%[[C0]], 0] [%[[C0]], 64] [%[[C0]], 1] channel = %[[CHANNEL]]) : !amdaie.logicalobjectfifo<memref<2048xi32>>
amdaie.npu.half_dma_cpy_nd %0(%arg0[%c0, 0] [%c0, 64] [%c0, 1] channel = %channel) : !amdaie.logicalobjectfifo<memref<2048xi32>>
// CHECK: amdaie.npu.half_dma_cpy_nd %[[CONNECTION_0]](%[[ARG0]] [] [] [] bd_id = %[[BD_ID]] channel = %[[CHANNEL]] use_next_bd = false) : !amdaie.logicalobjectfifo<memref<2048xi32>>
amdaie.npu.half_dma_cpy_nd %0(%arg0[] [] [] bd_id = %bd_id channel = %channel use_next_bd = false) : !amdaie.logicalobjectfifo<memref<2048xi32>>
// CHECK: amdaie.npu.half_dma_cpy_nd %[[CONNECTION_0]](%[[ARG0]] [] [] [] bd_id = %[[BD_ID]] channel = %[[CHANNEL]] use_next_bd = true next_bd = %[[BD_ID_1]] start_bd = %[[BD_ID]]) : !amdaie.logicalobjectfifo<memref<2048xi32>>
amdaie.npu.half_dma_cpy_nd %0(%arg0[] [] [] bd_id = %bd_id channel = %channel use_next_bd = true next_bd = %bd_id_1 start_bd = %bd_id) : !amdaie.logicalobjectfifo<memref<2048xi32>>
return
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@ struct HalfDmaCpyNdToNpuConverter final
staticStrides.insert(staticStrides.begin(),
numIntraAddrDim - staticStrides.size(), 0);

bool useNextBd = op.getUseNextBd();
bool useNextBd = op.getUseNextBd().value_or(false);
int32_t nextBd{0};
if (useNextBd) {
std::optional<AMDAIE::BdIdOp> nextBdIdOp = op.getNextBdIdOp();
Expand Down Expand Up @@ -216,19 +216,18 @@ struct HalfDmaCpyNdToNpuConverter final
if (failed(npuPushToQueueOp)) return failure();
rewriter.replaceOp(op, *npuPushToQueueOp);

bool useNextBd = op.getUseNextBd();
if (useNextBd)
// Erase if not end of chain.
bool useNextBd = op.getUseNextBd().value_or(false);
if (useNextBd) {
// `useNextBd` is true, so either at the beginning or middle of a chain.
// No need to push to the queue, just erase the op.
rewriter.eraseOp(*npuPushToQueueOp);
else {
} else {
std::optional<AMDAIE::BdIdOp> maybeStartBdIdOp = op.getStartBdIdOp();
if (maybeStartBdIdOp) {
// Update the BD ID with the start of the chain.
uint32_t startBdId =
getConstantIndexOrAssert(maybeStartBdIdOp.value().getValue());
uint32_t bdId =
getConstantIndexOrAssert(maybeBdIdOp.value().getValue());
if (startBdId != bdId) npuPushToQueueOp->setBdId(startBdId);
// Update with the BD ID at the start of the chain.
AMDAIE::BdIdOp startBdIdOp = maybeStartBdIdOp.value();
uint32_t startBdId = getConstantIndexOrAssert(startBdIdOp.getValue());
npuPushToQueueOp->setBdId(startBdId);
}
}
return success();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -67,12 +67,6 @@ void AMDAIEDmaCompositionPass::runOnOperation() {
"after strided op composition";
return signalPassFailure();
}

if (failed(moveNpuSourceDmaSyncAfterTargetDmaCpy(rewriter, parentOp))) {
parentOp->emitOpError()
<< "failed to move source DMA sync after target DMA copy";
return signalPassFailure();
}
}

} // namespace
Expand Down
Loading

0 comments on commit 82f87b3

Please sign in to comment.