Add a pass to fold DMA waits #962

Merged: 11 commits into main, Dec 9, 2024
Conversation

@Yu-Zhewen (Contributor) commented Dec 5, 2024:

Each DMA channel has a task queue with a depth of 4, so a DMA wait is only required after every 4 pushes; this pass removes the unnecessary synchronization in between.

Example: https://gist.github.com/Yu-Zhewen/5f569b56c7b1f1a8715a7c4c3bf9e609

Results compared to 7c4b985:

| Test (MxKxN) | Instruction Size Before (Words) | Instruction Size After (Words) |
|--------------|---------------------------------|--------------------------------|
| 512x4096x512 | 1228                            | 1132                           |
| 512x512x4096 | 820                             | 772                            |
| 4096x512x512 | 4628                            | 4244                           |

This optimization is orthogonal to DMA chaining #931.
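The folding rule can be sketched as follows (an illustrative helper, not the actual pass; the function name and encoding are invented). Walking the pushes on one channel in reverse, a wait is kept on the final push and then on every 4th push before it:

```cpp
#include <cassert>
#include <vector>

// Illustrative sketch (hypothetical helper, not the actual pass): given
// `numPushes` DMA pushes on one channel whose task queue holds `maxQueueSize`
// in-flight tasks, return the 0-based indices of the pushes that must keep a
// `dma_wait`. Walking backwards from the last push, a wait is kept on the
// final push and then on every `maxQueueSize`-th push before it.
std::vector<int> waitsToKeep(int numPushes, int maxQueueSize) {
  std::vector<int> keep;
  int distanceToKeptWait = 0;
  for (int i = numPushes - 1; i >= 0; --i) {
    if (distanceToKeptWait == 0) keep.push_back(i);
    distanceToKeptWait = (distanceToKeptWait + 1) % maxQueueSize;
  }
  // Indices were collected in reverse; return them in program order.
  return std::vector<int>(keep.rbegin(), keep.rend());
}
```

With 5 pushes and a queue depth of 4, this keeps a wait on the first and the last push, which is the shape of output discussed in the review conversation.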

@Yu-Zhewen Yu-Zhewen changed the title Add a pass to simplify DMA waits Add a pass to fold DMA waits Dec 5, 2024
@newling previously requested changes Dec 5, 2024
@newling (Contributor) commented:
Please add comments about what is being tested in each case. I can see there are fewer dma_wait operations after the pass, but it's not clear to me which ones are being removed. Also a bit surprised that there are no CHECK-NOT or CHECK-NEXT statements.

@@ -378,6 +378,8 @@ struct AMDAIEDeviceModel {
DenseMap<uint32_t, SmallVector<uint32_t>> getChannelToValidBdIds(
AMDAIETileType tileType) const;

uint8_t getDmaMaxQueueSize(uint8_t col, uint8_t row);
@newling (Contributor) commented:
Really unfortunate that this can't be const. I see that getTileType uses a const_cast to work around the lack of const-correctness in aie_rt; it seems that was first done here:

6c4f905#diff-17008229092a63d5df9831105108aa99c799a226441b0dd9d8327708e57fc2aeR218

I tracked this down because I saw you passing AMDAIE::AMDAIEDeviceModel deviceModel by value, which isn't great as this isn't a pointer-like type, so there is copy overhead. But if getDmaMaxQueueSize isn't const, deviceModel can't be a const ref where you use it...
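The const_cast workaround being discussed can be sketched like this (all names here are hypothetical stand-ins, not the actual aie_rt or AMDAIEDeviceModel API):

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical stand-in for a C-style API (like aie_rt's) that takes a
// non-const pointer even though it only reads from the struct.
struct FakeDevInst {
  uint8_t queueDepth;
};
uint8_t fakeGetQueueSize(FakeDevInst *devInst) { return devInst->queueDepth; }

// Sketch of the workaround: a logically-const accessor strips const before
// calling the const-incorrect API, so callers can hold the device model by
// const reference. All names are invented for illustration.
class DeviceModelSketch {
 public:
  explicit DeviceModelSketch(uint8_t depth) : devInst{depth} {}
  uint8_t getDmaMaxQueueSize() const {
    // Safe only because fakeGetQueueSize never mutates its argument.
    return fakeGetQueueSize(const_cast<FakeDevInst *>(&devInst));
  }

 private:
  FakeDevInst devInst;
};

// Demonstrates that the accessor is usable through a const reference.
uint8_t queryQueueSizeConst(const DeviceModelSketch &model) {
  return model.getDmaMaxQueueSize();
}
```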

<< "expected to operate on an `amdaie.flow`";
return WalkResult::interrupt();
}
if (maybeFlowOp->getIsPacketFlow()) return WalkResult::advance();
@newling (Contributor) commented:
This advance -- so you skip to the next waitOp, effectively making toErase = false for this waitOp? I think the logic would be easier to follow if you made standalone functions. Maybe a function like

LogicalResult canFoldWaitOp(WaitOp waitOp, AMDAIE::AMDAIEDeviceModel deviceModel, ...) {

}

Comment on lines +91 to +92
// CHECK: %[[TOKEN_0:.+]] = amdaie.npu.half_dma_cpy_nd async %[[CONNECTION]](%[[OBJECT_FIFO_1]] [] [] [] bd_id = %[[BD_ID_0]] channel = %[[CHANNEL_0]]) : !amdaie.logicalobjectfifo<memref<2048xi32>>
// CHECK: amdaie.npu.dma_wait(%[[TOKEN_0]] : !amdaie.async_token)
@jtuyls (Collaborator) commented:
Why do we have a wait on the first dma_cpy_nd?

I think we should expect:

dma_cpy_nd
dma_cpy_nd
dma_cpy_nd
%0 = dma_cpy_nd
dma_wait(%0)

Instead of:

%0 = dma_cpy_nd
dma_wait(%0)
dma_cpy_nd
dma_cpy_nd
%1 = dma_cpy_nd
dma_wait(%0)

@Yu-Zhewen (Author) commented Dec 6, 2024:
I am traversing the controlcode in reverse order, and the example actually is:

%0 = dma_cpy_nd
dma_wait(%0)
dma_cpy_nd
dma_cpy_nd
dma_cpy_nd
%1 = dma_cpy_nd
dma_wait(%1)

There are four dma_cpy_nd ops between the two waits.

@jtuyls (Collaborator) commented:
Right, I missed the fourth op between the waits, but we still shouldn't have a wait on the first op. This does matter if the number of dma_cpy_nd ops is smaller than or equal to 4.

@jtuyls (Collaborator) commented:
Sorry, it's not only the nb_dma_cpy_nd_ops <= 4 case: in the example above, you're also using two waits for what could be implemented with one.

@Yu-Zhewen (Author) commented Dec 6, 2024:
but if we don't have a wait on the first op, then we will have to do this:

dma_cpy_nd
dma_cpy_nd
dma_cpy_nd
%0 = dma_cpy_nd
dma_wait(%0)
%1 = dma_cpy_nd
dma_wait(%1)

we still have two waits, since we need one wait at the end of the controlcode anyway. That's why I chose to traverse in reverse order.

@jtuyls (Collaborator) commented Dec 6, 2024:
> Otherwise, I need to check if the current wait is the last one for each connection

You can traverse in forward order and keep track of every 4th and the last DMA on a connection; then, in a second pass, you keep only those waits. I don't see how that's more complex?
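The forward-order alternative being suggested could look roughly like this (an illustrative sketch; connections are modeled as plain strings rather than actual ops, and the names are invented):

```cpp
#include <algorithm>
#include <cassert>
#include <map>
#include <string>
#include <vector>

// Sketch of the forward-order alternative (hypothetical encoding: each push
// is tagged with its connection's name). A first pass counts pushes per
// connection, keeping a wait on every `maxQueueSize`-th push; the last push
// of each connection also keeps its wait if not already kept.
std::vector<int> forwardWaitsToKeep(const std::vector<std::string> &connOfPush,
                                    int maxQueueSize) {
  std::map<std::string, int> pushCount;  // pushes seen per connection
  std::map<std::string, int> lastIndex;  // last push index per connection
  std::vector<int> keep;
  for (int i = 0; i < static_cast<int>(connOfPush.size()); ++i) {
    const std::string &conn = connOfPush[i];
    lastIndex[conn] = i;
    if (++pushCount[conn] % maxQueueSize == 0) keep.push_back(i);
  }
  for (const auto &entry : lastIndex)
    if (pushCount[entry.first] % maxQueueSize != 0)
      keep.push_back(entry.second);
  std::sort(keep.begin(), keep.end());
  return keep;
}
```

With 5 pushes on one connection and a queue depth of 4, this keeps waits on pushes 3 and 4, i.e. the "intuitive" output shape requested above.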

@jtuyls (Collaborator) commented Dec 6, 2024:
The reason I think this is important is that the output IR should be as one would intuitively expect. A lot of the time I have to debug issues by just reading and understanding the IR and non-intuitive output is no fun when you're doing that.

@Yu-Zhewen (Author) commented:
I tried something, but then I realized that traversing in forward order complicates the management of BD IDs. When encountering a duplicate BD ID, we need to keep the wait for the last DMA which used that BD ID, and all the erasure decisions between those two DMAs need to be updated accordingly.
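The interaction between reverse traversal and duplicate BD IDs can be sketched like this (hypothetical encoding in which each push is represented only by its BD ID; this is not the actual pass):

```cpp
#include <cassert>
#include <set>
#include <vector>

// Sketch of a reverse traversal that keeps a wait when the queue window is
// full, or when a push reuses a BD ID that is still pending in the window
// (hypothetical helper; pushes are encoded as bare BD IDs).
std::vector<int> waitsToKeepWithBdIds(const std::vector<int> &bdIdOfPush,
                                      int maxQueueSize) {
  std::vector<int> keep;
  std::set<int> pending;  // BD IDs in flight since the last kept wait
  for (int i = static_cast<int>(bdIdOfPush.size()) - 1; i >= 0; --i) {
    bool queueFull = static_cast<int>(pending.size()) == maxQueueSize;
    bool duplicateBdId = pending.count(bdIdOfPush[i]) > 0;
    if (pending.empty() || queueFull || duplicateBdId) {
      keep.push_back(i);  // this push's wait must stay
      pending.clear();    // the kept wait drains the queue
    }
    pending.insert(bdIdOfPush[i]);
  }
  return std::vector<int>(keep.rbegin(), keep.rend());
}
```

Under this sketch, two consecutive pushes reusing the same BD ID both keep their waits, matching the lit test where no waits are folded because the same BD ID is used.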

@jtuyls (Collaborator) commented:
Ok, could you add documentation to the function on why you iterate in reverse order, and an example showing the expected output for one of the more quirky cases above?

@Yu-Zhewen (Author) commented:
Thanks, added now

@newling (Contributor) left a comment:

Thanks for the comments in the lit tests, they helped me

AMDAIE::ConnectionOp connectionOp = maybeConnectionOp.value();
// Retrieve the flow op.
std::optional<AMDAIE::FlowOp> maybeFlowOp =
maybeConnectionOp->getFlowOp();
@newling (Contributor) commented:
Suggested change:
- maybeConnectionOp->getFlowOp();
+ connectionOp.getFlowOp();

[&](AMDAIE::NpuDmaWaitOp waitOp) {
bool toErase = true;
for (Value token : waitOp.getAsyncTokens()) {
if (auto npuHalfDmaCpyNdOp =
@newling (Contributor) commented:
I still kinda think a function at the level of

FailureOr canFoldBasedOnNpuHalfDmaCpyNdOp(...)

would make for slightly easier-to-read (less indented) code.

@Yu-Zhewen (Author) commented:
Yes, made it a function now

uint32_t row = getConstantIndexOrAssert(tileOp.getRow());
uint32_t maxQueueSize = deviceModel.getDmaMaxQueueSize(col, row);
// Keep wait op if, either reaches the maximum queue size, or there
// is a duplicate BD ID in the same tile.
@newling (Contributor) commented:
Suggested change:
- // is a duplicate BD ID in the same tile.
+ // is a duplicate BD ID in the same tile, or packet flow, or the queue is empty

?

// -----

// Expect no DMA waits to be folded, since the same BD ID is used.
// CHECK-LABEL: @fold_dma_waits_same_bd_id
@newling (Contributor) commented:
Could just be:

CHECK-COUNT-2: dma_wait
CHECK-NOT: dma_wait

@newling newling dismissed their stale review December 6, 2024 21:13

I'm unblocking this, letting Jorn accept / reject as I don't have enough context to know if this is good to land

@jtuyls (Collaborator) left a comment:

LGTM

@Yu-Zhewen Yu-Zhewen enabled auto-merge (squash) December 9, 2024 21:23
@Yu-Zhewen Yu-Zhewen merged commit 2243dd8 into main Dec 9, 2024
8 checks passed
@Yu-Zhewen Yu-Zhewen deleted the zhewen_remove_wait branch December 9, 2024 22:53
Yu-Zhewen added a commit that referenced this pull request Dec 18, 2024

This is an enhancement for #962.

In the previous PR, DMA waits on the same `connection` (and the same
tile) could be folded, exploiting the fact that each DMA channel has a
queue size of 4.

In this PR, DMA waits across multiple `columns` can also be folded,
provided their corresponding `row`, `channel`, and `direction` are the
same. This optimization leverages the ability to specify `colNum` in
`TCTSync`, where the range `[col, col + colNum)` can be addressed.
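The column-folding condition can be sketched as follows (illustrative only, not the actual pass): waits sharing the same `row`, `channel`, and `direction` can be merged per contiguous column run, since one sync can address a whole range `[col, col + colNum)`:

```cpp
#include <algorithm>
#include <cassert>
#include <vector>

// Sketch (hypothetical helper): given the columns of waits that share row,
// channel, and direction, return how many syncs remain after folding, i.e.
// the number of contiguous column runs, each covered by one ranged sync.
int numFoldedSyncs(std::vector<int> cols) {
  std::sort(cols.begin(), cols.end());
  cols.erase(std::unique(cols.begin(), cols.end()), cols.end());
  int runs = cols.empty() ? 0 : 1;
  for (size_t i = 1; i < cols.size(); ++i)
    if (cols[i] != cols[i - 1] + 1) ++runs;  // a gap starts a new sync range
  return runs;
}
```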

The numbers in the following table show the instruction size in words.

| Test (MxKxN) | No Folding | Only Fold by Connection | Only Fold by Column | Fold Both |
|--------------|------------|-------------------------|---------------------|-----------|
| 512x4096x512 | 1228       | 1132                    | 1120                | 1096      |
| 512x512x4096 | 820        | 772                     | 748                 | 736       |
| 4096x512x512 | 4628       | 4244                    | 4220                | 4124      |