
Improve dropout performance by unrolling 4xint random masks. #85

Merged · 6 commits into master on Jul 4, 2024

Conversation

@tongxin (Contributor) commented Jun 27, 2024

This PR applies the following changes to dropout:

  • Uses the 128-bit Philox output to generate four dropout masks per random-number call instead of one in the current version.
  • Replaces libentry and autotune with simple triton.heuristics for better API performance.

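The first change can be illustrated with a small NumPy analogue (this is a sketch of the idea, not the PR's actual Triton kernel; in the kernel the equivalent is a `tl.rand4x`-style call that yields four random blocks per Philox invocation):

```python
# Sketch: emulate the 4x-unrolled dropout-mask generation with NumPy's
# Philox bit generator. One Philox counter step produces several random
# words, so a single RNG call can service four mask elements instead of
# one -- the idea behind UNROLL = 4 in this PR. Names here are
# illustrative, not taken from flag_gems.
import numpy as np

UNROLL = 4  # four mask values consumed per Philox output block

def dropout_unrolled(x, p, seed=0):
    """Dropout with drop-probability p; survivors are scaled by 1/(1-p)."""
    rng = np.random.Generator(np.random.Philox(key=seed))
    # Draw all uniforms at once; the Triton kernel instead unrolls the
    # per-element loop by UNROLL, reusing one 128-bit Philox output.
    u = rng.random(x.shape)
    mask = u > p                      # True = keep the element
    scale = 1.0 / (1.0 - p)          # rescale so E[output] == input
    return np.where(mask, x * scale, 0.0), mask

x = np.ones(1 << 16, dtype=np.float32)
y, mask = dropout_unrolled(x, p=0.25, seed=42)
```

Dropped positions are exactly zero, kept positions carry the `1/(1-p)` scale, and the keep fraction converges to `1 - p`.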
Improved performance on A100 is shown below.

dropout-performance:

|   | M            | torch      | gems        |
|---|--------------|------------|-------------|
| 0 | 5.000000e+02 | 0.199681   | 0.230627    |
| 1 | 4.096000e+03 | 1.646302   | 2.031746    |
| 2 | 1.638400e+04 | 8.094861   | 7.757576    |
| 3 | 6.553600e+04 | 28.151202  | 31.148290   |
| 4 | 1.000000e+05 | 40.716611  | 46.992483   |
| 5 | 1.000000e+09 | 952.747469 | 1326.879422 |
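The second change, replacing autotune with `triton.heuristics`, computes kernel meta-parameters from the runtime arguments with small lambdas instead of a benchmarking pass. A pure-Python analogue of that pattern (simplified; the real decorator is `triton.heuristics`, and the function names here are illustrative, not from the PR):

```python
# Sketch of the triton.heuristics pattern: each meta-parameter is filled
# in by a lambda over the call's named arguments, with no autotune
# benchmarking. This toy decorator mimics the API shape only.
def heuristics(values):
    def deco(fn):
        def wrapper(*args, **kwargs):
            # Bind positionals by name (simplified: assumes no *args in fn).
            named = dict(zip(fn.__code__.co_varnames, args), **kwargs)
            for key, rule in values.items():
                named.setdefault(key, rule(named))
            return fn(**named)
        return wrapper
    return deco

def next_power_of_2(n):
    return 1 << (n - 1).bit_length()

@heuristics({"BLOCK_SIZE": lambda args: min(next_power_of_2(args["N"]), 1024)})
def launch(N, BLOCK_SIZE):
    # A real Triton kernel would receive BLOCK_SIZE as a tl.constexpr.
    return BLOCK_SIZE

print(launch(N=500))   # prints 512
```

Because the heuristic is a cheap pure function of the arguments, there is no warm-up cost on new shapes, which is where the API-overhead win over autotune comes from.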

Review comment from a Collaborator on the diff:

```python
    return x * scale


UNROLL = 4
```

Since `UNROLL = 4` is in each jit function, I suggest deleting this module-level constant.

@iclementine (Collaborator) left a comment:

LGTM

Two review threads on src/flag_gems/ops/dropout.py were resolved.
@StrongSpoon (Collaborator) left a comment:

lgtm

@StrongSpoon merged commit 802b605 into master on Jul 4, 2024 (3 checks passed).
@tongxin deleted the dropout branch on July 10, 2024 02:32.