Add a pass to fold DMA waits #962
Conversation
}
}

// -----
Please add comments about what is being tested in each case. I can see there are fewer dma_wait operations after the pass, but it's not clear to me which ones are being removed. Also, I'm a bit surprised that there are no CHECK-NOT or CHECK-NEXT statements.
@@ -378,6 +378,8 @@ struct AMDAIEDeviceModel {
  DenseMap<uint32_t, SmallVector<uint32_t>> getChannelToValidBdIds(
      AMDAIETileType tileType) const;

  uint8_t getDmaMaxQueueSize(uint8_t col, uint8_t row);
Really unfortunate that this can't be const. I see that getTileType uses a const_cast to work around the lack of const-correctness in aie_rt; it seems this was first done here:
6c4f905#diff-17008229092a63d5df9831105108aa99c799a226441b0dd9d8327708e57fc2aeR218
I tracked this down because I saw you passing AMDAIE::AMDAIEDeviceModel deviceModel by value, which isn't great: this isn't a pointer-like type, so there is copy overhead. But if getDmaMaxQueueSize isn't const, deviceModel can't be a const ref where you use it...
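For illustration only, a minimal sketch of the pattern under discussion, with hypothetical stand-in types (DevInst, legacyQueryQueueSize, and this DeviceModel are invented for the sketch, not the actual aie_rt or device-model API): a logically-const query method const_casts around a non-const-correct C signature, which in turn lets callers take the model by const reference instead of by value.

```cpp
#include <cstdint>

// Stand-in for a C API (like aie_rt) that takes a non-const pointer even
// though it does not modify the device: the missing const-correctness.
struct DevInst {};
uint8_t legacyQueryQueueSize(DevInst *dev, uint8_t col, uint8_t row) {
  (void)dev; (void)col; (void)row;
  return 4;  // hypothetical: every DMA channel queue has depth 4
}

struct DeviceModel {
  DevInst dev;
  // Logically const: the query does not mutate the model, so we const_cast
  // away the constness imposed by the legacy C signature.
  uint8_t getDmaMaxQueueSize(uint8_t col, uint8_t row) const {
    return legacyQueryQueueSize(const_cast<DevInst *>(&dev), col, row);
  }
};

// With the method const, callers can take the model by const reference,
// avoiding the copy that pass-by-value would incur.
uint8_t query(const DeviceModel &deviceModel) {
  return deviceModel.getDmaMaxQueueSize(/*col=*/0, /*row=*/0);
}
```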
<< "expected to operate on an `amdaie.flow`"; | ||
return WalkResult::interrupt(); | ||
} | ||
if (maybeFlowOp->getIsPacketFlow()) return WalkResult::advance(); |
This advance -- so you skip to the next waitOp, effectively making toErase = false for this waitOp? I think the logic would be easier to follow if you made standalone functions. Maybe a function like
LogicalResult canFoldWaitOp(WaitOp waitOp, AMDAIE::AMDAIEDeviceModel deviceModel, ...) {
}
// CHECK: %[[TOKEN_0:.+]] = amdaie.npu.half_dma_cpy_nd async %[[CONNECTION]](%[[OBJECT_FIFO_1]] [] [] [] bd_id = %[[BD_ID_0]] channel = %[[CHANNEL_0]]) : !amdaie.logicalobjectfifo<memref<2048xi32>>
// CHECK: amdaie.npu.dma_wait(%[[TOKEN_0]] : !amdaie.async_token)
Why do we have a wait on the first dma_cpy_nd?
I think we should expect:
dma_cpy_nd
dma_cpy_nd
dma_cpy_nd
%0 = dma_cpy_nd
dma_wait(%0)
Instead of:
%0 = dma_cpy_nd
dma_wait(%0)
dma_cpy_nd
dma_cpy_nd
%1 = dma_cpy_nd
dma_wait(%1)
I am traversing the controlcode in reverse order, and the example actually is:
%0 = dma_cpy_nd
dma_wait(%0)
dma_cpy_nd
dma_cpy_nd
dma_cpy_nd
%1 = dma_cpy_nd
dma_wait(%1)
There are four dma_cpy_nd ops in between.
Right, I missed the fourth op between the waits, but we still shouldn't have a wait on the first op. This does matter if the number of dma_cpy_nd ops is smaller than or equal to 4.
Sorry, not only for nb_dma_cpy_nd_ops <= 4; also in the example above, you're using two waits for what could be implemented with one.
But if we don't have a wait on the first op, then we will have to do this:
dma_cpy_nd
dma_cpy_nd
dma_cpy_nd
%0 = dma_cpy_nd
dma_wait(%0)
%1 = dma_cpy_nd
dma_wait(%1)
We still have two waits, since we need one wait at the end of the controlcode anyway, right? That's why I chose to traverse in reverse order.
> Otherwise, I need to check if the current wait is the last one for each connection

You can traverse in forward order and keep track of every 4th and the last DMA on a connection; then, in a second pass, keep only those waits. I don't see how that's more complex?
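A hedged sketch of this forward, two-pass idea, with purely illustrative types and names (Dma and waitsToKeep are invented for the sketch, not the actual pass API): the first pass records, per connection, every 4th DMA plus the last one; the second pass would then keep only the waits attached to those DMAs.

```cpp
#include <cstddef>
#include <map>
#include <set>
#include <vector>

// Illustrative stand-in for a DMA op: its connection and its position in the
// controlcode.
struct Dma {
  int connection;
  std::size_t index;
};

// Pass 1: count DMAs per connection, marking every `queueDepth`-th one and
// remembering the last one seen. Pass 2: the last DMA on each connection
// always keeps its wait. The returned indices are the waits to keep.
std::set<std::size_t> waitsToKeep(const std::vector<Dma> &dmas,
                                  std::size_t queueDepth = 4) {
  std::map<int, std::size_t> countPerConnection;
  std::map<int, std::size_t> lastPerConnection;
  std::set<std::size_t> keep;
  for (const Dma &d : dmas) {
    std::size_t n = ++countPerConnection[d.connection];
    if (n % queueDepth == 0) keep.insert(d.index);
    lastPerConnection[d.connection] = d.index;
  }
  for (const auto &entry : lastPerConnection) keep.insert(entry.second);
  return keep;
}
```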
The reason I think this is important is that the output IR should be what one would intuitively expect. A lot of the time I have to debug issues just by reading and understanding the IR, and non-intuitive output is no fun when you're doing that.
I tried something, but then I realized that traversing in forward order complicates the management of BD IDs. When encountering a duplicate BD ID, we need to keep the wait for the last DMA which used that BD ID. Also, all the erasure decisions between those two DMAs need to be updated accordingly.
Ok, could you add documentation to the function on why you iterate in reverse order and an example to show expected output for one of these more quirky cases above?
Thanks, added now
Thanks for the comments in the lit tests, they helped me
    AMDAIE::ConnectionOp connectionOp = maybeConnectionOp.value();
    // Retrieve the flow op.
    std::optional<AMDAIE::FlowOp> maybeFlowOp =
        maybeConnectionOp->getFlowOp();
Suggested change:
-        maybeConnectionOp->getFlowOp();
+        connectionOp.getFlowOp();
      [&](AMDAIE::NpuDmaWaitOp waitOp) {
        bool toErase = true;
        for (Value token : waitOp.getAsyncTokens()) {
          if (auto npuHalfDmaCpyNdOp =
I still kinda think a function at the level of
FailureOr canFoldBasedOnNpuHalfDmaCpyNdOp(...)
would make for slightly easier-to-read (less indented) code.
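For what it's worth, a rough sketch of the shape of such an extraction, using simplified stand-in types rather than the real MLIR/AMDAIE ones (HalfDmaCpyNdOp and WaitOp here are invented; std::optional stands in for FailureOr):

```cpp
#include <optional>

// Simplified stand-ins for the real ops; illustrative only.
struct HalfDmaCpyNdOp {
  bool isPacketFlow;
};
struct WaitOp {
  std::optional<HalfDmaCpyNdOp> producer;  // the DMA op producing the token
};

// Standalone predicate at the level the review suggests: std::nullopt plays
// the role of failure, the bool of "can this wait be folded away". Pulling
// the lambda body out into this shape flattens the nesting in the walk.
std::optional<bool> canFoldBasedOnNpuHalfDmaCpyNdOp(const WaitOp &waitOp) {
  if (!waitOp.producer) return std::nullopt;        // failure: no producer
  if (waitOp.producer->isPacketFlow) return false;  // packet flows keep waits
  return true;
}
```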
Yes, made it a function now
      uint32_t row = getConstantIndexOrAssert(tileOp.getRow());
      uint32_t maxQueueSize = deviceModel.getDmaMaxQueueSize(col, row);
      // Keep wait op if, either reaches the maximum queue size, or there
      // is a duplicate BD ID in the same tile.
Suggested change:
-      // is a duplicate BD ID in the same tile.
+      // is a duplicate BD ID in the same tile, or packet flow, or the queue is empty
?
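To make the combined condition concrete, a hypothetical predicate bundling the cases named in this thread (illustrative only, not the pass's actual code):

```cpp
#include <cstddef>

// A wait must be kept if the channel's task queue reached its maximum size,
// if a BD ID repeats on the same tile, if the flow is a packet flow, or if
// the queue is empty (nothing earlier to fold the synchronization into).
bool mustKeepWait(std::size_t queueSize, std::size_t maxQueueSize,
                  bool duplicateBdId, bool isPacketFlow) {
  return queueSize >= maxQueueSize || duplicateBdId || isPacketFlow ||
         queueSize == 0;
}
```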
// -----

// Expect no DMA waits to be folded, since the same BD ID is used.
// CHECK-LABEL: @fold_dma_waits_same_bd_id
Could just be:
CHECK-COUNT-2: dma_wait
CHECK-NOT: dma_wait
I'm unblocking this and letting Jorn accept/reject, as I don't have enough context to know if this is good to land.
LGTM
This is an enhancement for #962. In the previous PR, DMA waits on the same `connection` (and the same tile) could be folded, exploiting the fact that each DMA channel has a queue size of 4. In this PR, DMA waits across multiple `columns` can also be folded, provided their corresponding `row`, `channel`, and `direction` are the same. This optimization leverages the ability to specify `colNum` in `TCTSync`, where the range `[col, col + colNum)` can be addressed.

The numbers in the following table show the instruction size in words.

| Test (MxKxN) | No Folding | Only Fold by Connection | Only Fold by Column | Fold Both |
|--------------|------------|-------------------------|---------------------|-----------|
| 512x4096x512 | 1228 | 1132 | 1120 | 1096 |
| 512x512x4096 | 820 | 772 | 748 | 736 |
| 4096x512x512 | 4628 | 4244 | 4220 | 4124 |
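As a loose illustration of the cross-column idea (hypothetical names throughout; `tctSync` below is invented, the real `TCTSync` lives in aie_rt): waits that share `row`, `channel`, and `direction` but sit in contiguous columns can be replaced by a single sync covering `[col, col + colNum)`.

```cpp
#include <algorithm>
#include <cstddef>
#include <vector>

struct PendingWait {
  int col, row, channel, direction;
};

// Given waits already grouped by (row, channel, direction), emit one sync per
// contiguous column range instead of one sync per wait.
void emitFoldedSyncs(std::vector<PendingWait> waits) {
  std::sort(waits.begin(), waits.end(),
            [](const PendingWait &a, const PendingWait &b) {
              return a.col < b.col;
            });
  for (std::size_t i = 0; i < waits.size();) {
    std::size_t j = i + 1;
    // Extend the run while columns stay contiguous.
    while (j < waits.size() && waits[j].col == waits[j - 1].col + 1) ++j;
    int col = waits[i].col;
    int colNum = static_cast<int>(j - i);  // addresses [col, col + colNum)
    // hypothetical: tctSync(col, waits[i].row, waits[i].channel,
    //                       waits[i].direction, colNum);
    (void)col; (void)colNum;
    i = j;
  }
}
```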
Each DMA channel has a task queue with a depth of 4, so a DMA wait is only required once every 4 pushes, reducing unnecessary synchronization.
Example: https://gist.github.com/Yu-Zhewen/5f569b56c7b1f1a8715a7c4c3bf9e609
Results compared to 7c4b985:
This optimization is orthogonal to DMA chaining #931.