Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[SWP] Fix a bug in SWP that did not correctly compute the number of loop iterations #4887

Open
wants to merge 2 commits into
base: main
Choose a base branch
from

Conversation

sfzhu93
Copy link
Contributor

@sfzhu93 sfzhu93 commented Oct 10, 2024

This PR is picked from #4689 for clarity. Prior to this PR, the loop lower and outer bound for tl.range was incorrect. Now it correctly casts to the data type that holds the bounds in tl.range. Also, prior to this PR, the maxStage was set to zero, and we correctly set it to the value come from num_stages in tl.range.

It is likely that it's directly inherited from the MLIR upstream pipelining algorithm: https://github.com/llvm/llvm-project/blob/main/mlir/lib/Dialect/SCF/Transforms/LoopPipelining.cpp#L110. However, seems Triton uses different data types from the MLIR upstream algorithm.

  • I am not making a trivial change, such as fixing a typo in a comment.

  • I have written a PR description following these
    rules.

  • I have run pre-commit run --from-ref origin/main --to-ref HEAD.

  • Select one of the following.

    • I have added tests.
      • /python/test for end-to-end tests
  • Select one of the following.

    • I have not added any lit tests.

int64_t numIteration = llvm::divideCeilSigned(ubImm - lbImm, stepImm);
if (numIteration > maxStage) {
if (numIteration > options.numStages - 1) {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In what case maxStage is different than options.numStages - 1?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

At this line, maxStage is 0, which is different from options.numStages - 1. This change was made per Pawel's comment.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

so if maxStage is 0 numStages is 1. Pawel's comment explains it: maxStage == numStages-1

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I see. In that case, did this line check if there's at least one iteration? Do we still need to check the number of iterations larger than num_stages from tt level?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

so if there is 0 iterations this condition is always false. I'm not sure I understand the questions

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The problem seems to be that at this point maxStage is initialized to zero but not set yet, maxStage is set later on after we have a schedule. This diff looks generally okay to me. But it caused divergence from upstream mlir.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ah right, we should move the calculation of maxStage up

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah the problem is that the actual maxStage depends on the pipeline schedule and here we are trying to early exit before getting a schedule, when we know the iteration count is too small for pipelining to be beneficial. So we are using numStages - 1 to check for early exit here.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do I need further improvement on this PR?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This PR looks okay to me. Is it possible to upload part of patch to upstream mlir?

@sfzhu93 sfzhu93 force-pushed the swp-num-stages-bug branch from 2ccc21b to c4279df Compare October 11, 2024 18:46
…gic (triton-lang#4880)

In this PR, we add one unit test that triggers an error reported during
software pipelining.

The testing code for remarks and debug is also improved. Previously, the
environment variables are not reset back to default value when assertion
fails, which will impact subsequent unit tests. In this PR, we use
`with` statement and context manager to ensure that environment
variables are always correctly reset after each test, regardless of
whether the test passes or fails. In addition, a fixture
`fresh_triton_cache` is added to ensure the kernel is recompiled every
time the test runs. These changes make the test code easier to run and
the testing process more robust.

This PR is one of the changes in triton-lang#4689. I create this new PR for clarity
and resolve rebasing conflicts.

- [x] I am not making a trivial change, such as fixing a typo in a
comment.

- [x] I have written a PR description following these
[rules](https://cbea.ms/git-commit/#why-not-how).

- [x] I have run `pre-commit run --from-ref origin/main --to-ref HEAD`.

- Select one of the following.
  - [x] I have added tests.
    - `/python/test` for end-to-end tests

- Select one of the following.
  - [x] I have not added any `lit` tests.
@sfzhu93 sfzhu93 force-pushed the swp-num-stages-bug branch from c4279df to 1afd5ca Compare October 11, 2024 18:47
@sfzhu93 sfzhu93 changed the title [WIP][SWP] Fix a bug in SWP that did not correctly compute the number of loop iterations (Depends on #4880) [WIP][SWP] Fix a bug in SWP that did not correctly compute the number of loop iterations Oct 11, 2024
@sfzhu93 sfzhu93 changed the title [WIP][SWP] Fix a bug in SWP that did not correctly compute the number of loop iterations [SWP] Fix a bug in SWP that did not correctly compute the number of loop iterations Oct 17, 2024
int64_t numIteration = llvm::divideCeilSigned(ubImm - lbImm, stepImm);
if (numIteration > maxStage) {
if (numIteration > options.numStages - 1) {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The problem seems to be that at this point maxStage is initialized to zero but not set yet, maxStage is set later on after we have a schedule. This diff looks generally okay to me. But it caused divergence from upstream mlir.

@sfzhu93 sfzhu93 marked this pull request as ready for review October 18, 2024 20:24
@sfzhu93 sfzhu93 requested a review from ptillet as a code owner October 18, 2024 20:24
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants