Optimize softmax #69
1. Ensure that decorator cascading works as expected, i.e. an inner decorator can use arguments provided by an outer decorator.
2. Ensure that the grid function can use all the arguments provided by decorators (Autotuner & Heuristics).
3. Simplify LibEntry: extract captured constant arguments from CompiledKernel instead of traversing layers of decorators.
…e ONE_TILE_PER_CTA static condition(to decide whether to load only one tile per cta.
Need to rebase after PR #68 is merged.
@@ -535,8 +535,8 @@ def _torch_rms_norm(x, residual, weight, eps):
@pytest.mark.parametrize("shape", REDUCTION_SHAPES)
@pytest.mark.parametrize("dtype", FLOAT_DTYPES)
def test_accuracy_softmax(shape, dtype):
    dim = 1
@pytest.mark.parametrize("dim", [0, 1])
DIM_LIST
"dim" passed to the function here is an integer.
Do you mean we should add tests for non-inner reduction in general?
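For reference, a minimal sketch of what a non-inner reduction check could look like, written in plain Python rather than against the project's test helpers. It assumes a contiguous row-major input viewed as (M, N, K), where N is the reduced dim; K == 1 is the inner case and K > 1 the non-inner case. The names `view_as_mnk` and `softmax_flat` are hypothetical and do not come from this PR.

```python
import math


def view_as_mnk(shape, dim):
    """Collapse a contiguous shape to (M, N, K), where N is the reduced dim."""
    dim = dim % len(shape)
    m = math.prod(shape[:dim])       # everything before the reduced dim
    n = shape[dim]                   # the reduced dim itself
    k = math.prod(shape[dim + 1:])   # everything after it (1 for the inner case)
    return m, n, k


def softmax_flat(data, shape, dim):
    """Reference softmax along `dim` for a row-major flat list of floats."""
    m, n, k = view_as_mnk(shape, dim)
    out = [0.0] * len(data)
    for i in range(m):
        for j in range(k):
            # Elements of one reduction line are k apart in memory.
            idx = [i * n * k + r * k + j for r in range(n)]
            mx = max(data[t] for t in idx)
            exps = [math.exp(data[t] - mx) for t in idx]
            total = sum(exps)
            for t, e in zip(idx, exps):
                out[t] = e / total
    return out


data = [float(v) for v in range(6)]       # a (2, 3) input
inner = softmax_flat(data, (2, 3), 1)     # viewed as (2, 3, 1)
non_inner = softmax_flat(data, (2, 3), 0) # viewed as (1, 2, 3)
```

With this view, parametrizing a test over dim exercises both layouts: the last dim gives stride-1 reduction lines, any other dim gives strided ones.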
I didn't resolve the conflicts properly, so I opened a new pull request: #76.
a. inner: for reduction along the last dim (the input is preprocessed to be contiguous)
b. non-inner: for reduction along other dimensions (the input is preprocessed to be contiguous)

ONE_TILE_PER_CTA static condition
a. when ONE_TILE_PER_CTA is True, load only one tile per CTA without looping over the reduction dim
b. when ONE_TILE_PER_CTA is False, use the online softmax normalizer algorithm to save one sweep over the input

We can leave other optimizations (optimizing according to input layout, or two-pass reduction) for future PRs.
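To illustrate the two ONE_TILE_PER_CTA branches, here is a rough pure-Python sketch, not the Triton kernel from this PR. It assumes finite inputs; `softmax_one_tile`, `softmax_online`, and `tile_size` are hypothetical names chosen for the example. The online variant keeps a running max and a rescaled running sum, so the max and the normalizer come out of a single pass over the tiles instead of two.

```python
import math


def softmax_one_tile(row):
    """The whole reduction line fits in one tile: no loop over the reduction dim."""
    m = max(row)
    exps = [math.exp(x - m) for x in row]
    z = sum(exps)
    return [e / z for e in exps]


def softmax_online(row, tile_size):
    """The line spans several tiles: track a running max and normalizer together."""
    m = float("-inf")  # running max seen so far
    z = 0.0            # running sum of exp(x - m) over the tiles seen so far
    for start in range(0, len(row), tile_size):
        tile = row[start:start + tile_size]
        m_new = max(m, max(tile))
        # Rescale the old sum to the new max, then add this tile's contribution.
        z = z * math.exp(m - m_new) + sum(math.exp(x - m_new) for x in tile)
        m = m_new
    # Final sweep: normalize with the max and sum computed above.
    return [math.exp(x - m) / z for x in row]


row = [1.0, 2.0, 3.0, 4.0, 5.0]
single = softmax_one_tile(row)
tiled = softmax_online(row, tile_size=2)
```

Both functions should agree up to floating-point rounding; the difference is only in how many times the input is traversed before normalization.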