
Optimize softmax #69

Closed
wants to merge 7 commits into from

Conversation

iclementine
Collaborator

@iclementine iclementine commented Jun 14, 2024

  1. Use different kernels (inner & non_inner) for softmax forward:
    a. inner: for reduction along the last dim (the input is preprocessed to be contiguous)
    b. non_inner: for reduction along other dims (the input is preprocessed to be contiguous)
  2. Both have a ONE_TILE_PER_CTA static condition:
    a. when ONE_TILE_PER_CTA is True, load only one tile per CTA without looping over the reduction dim
    b. when ONE_TILE_PER_CTA is False, use the online softmax normalizer algorithm to save one sweep over the input.

We can leave other optimizations (optimizing according to the input layout, or two-pass reduction) for future PRs.
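The online softmax normalizer mentioned in 2.b can be sketched in plain Python/NumPy. This is a generic illustration of the algorithm, not the PR's Triton kernel: the running maximum and the running sum of exponentials are maintained together in a single pass, with the sum rescaled whenever the maximum grows, so no separate max-finding sweep over the input is needed.

```python
import numpy as np

def online_softmax(x):
    """One-pass softmax normalizer: track the running max m and the
    running sum s of exp(x_i - m), rescaling s whenever m increases."""
    m = -np.inf   # running maximum
    s = 0.0       # running sum of exp(x_i - m)
    for v in x:
        m_new = max(m, v)
        # rescale the old sum to the new max, then add the new term
        s = s * np.exp(m - m_new) + np.exp(v - m_new)
        m = m_new
    # one final pass produces the normalized outputs
    return np.exp(np.asarray(x) - m) / s
```

The rescaling step `s * exp(m - m_new)` is what makes the single-pass normalizer numerically equivalent to the usual two-pass (max, then sum) formulation.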

1. ensure that decorator cascading works as expected, i.e. an inner decorator can use arguments provided by an outer decorator
2. ensure that the grid function can use all the arguments provided by decorators (Autotuner & Heuristics)
3. simplify LibEntry: extract captured constant arguments from CompiledKernel instead of traversing layers of decorators.
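The decorator-cascading behavior in item 1 can be illustrated with a small sketch. The decorator names and the `ONE_TILE_PER_CTA` rule below are hypothetical stand-ins for the library's Heuristics/Autotuner layers, not the actual FlagGems or Triton API: the outer decorator injects a derived keyword argument into `kwargs`, and the inner decorator can then see and use it.

```python
import functools

def heuristics(extra):
    """Outer decorator (hypothetical): derives extra keyword arguments
    from the call's kwargs and injects them before inner layers run."""
    def deco(fn):
        @functools.wraps(fn)
        def wrapper(*args, **kwargs):
            for name, rule in extra.items():
                kwargs.setdefault(name, rule(kwargs))
            return fn(*args, **kwargs)
        return wrapper
    return deco

def check_config(fn):
    """Inner decorator: it can see ONE_TILE_PER_CTA because the outer
    heuristics layer injected it into kwargs before this wrapper runs."""
    @functools.wraps(fn)
    def wrapper(*args, **kwargs):
        assert "ONE_TILE_PER_CTA" in kwargs, "outer decorator did not run first"
        return fn(*args, **kwargs)
    return wrapper

@heuristics({"ONE_TILE_PER_CTA": lambda kw: kw["n"] <= 1024})
@check_config
def softmax_kernel(x, n, ONE_TILE_PER_CTA):
    # a real kernel would branch on the static condition here
    return ONE_TILE_PER_CTA
```

Because Python applies decorators bottom-up but runs the wrappers top-down, the outermost wrapper mutates `kwargs` first, which is exactly the property item 1 asks to verify.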
…e ONE_TILE_PER_CTA static condition (to decide whether to load only one tile per CTA).
@StrongSpoon
Collaborator

need to rebase after merging pr68

@@ -535,8 +535,8 @@ def _torch_rms_norm(x, residual, weight, eps):

 @pytest.mark.parametrize("shape", REDUCTION_SHAPES)
 @pytest.mark.parametrize("dtype", FLOAT_DTYPES)
+@pytest.mark.parametrize("dim", [0, 1])
 def test_accuracy_softmax(shape, dtype):
-    dim = 1
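Parametrizing `dim` over `[0, 1]` exercises both kernel paths from the PR description. A reference softmax in NumPy (an illustration, not the library's kernel) makes the distinction concrete: reducing the last axis is the "inner" case, reducing any other axis is the "non_inner" case.

```python
import numpy as np

def softmax(x, dim):
    """Reference softmax along an arbitrary axis, max-subtracted for stability."""
    m = x.max(axis=dim, keepdims=True)
    e = np.exp(x - m)
    return e / e.sum(axis=dim, keepdims=True)

x = np.random.rand(4, 8)
inner = softmax(x, dim=1)      # reduce the last dim  -> "inner" kernel path
non_inner = softmax(x, dim=0)  # reduce a non-last dim -> "non_inner" kernel path
```

Each output sums to 1 along the reduced axis, which is what the accuracy test checks against the PyTorch reference.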
Collaborator
DIM_LIST

Collaborator Author

The "dim" passed to the function here is an integer.
Do you mean we should add a test for non-inner reduction in general?

@iclementine
Collaborator Author

need to rebase after merging pr68

I didn't resolve the conflicts properly, so I opened a new pull request, #76.

@iclementine iclementine deleted the optimize_softmax branch June 19, 2024 03:07