Multinomial #141
Conversation
- Exponential added.
- Added K-S tests to exponential_ (see the sketch below); fp64 corrected.
- Aligned with the aten prototype.
- Exponential_ uses uint64 offsets in the Triton kernel.
- Update pyproject config for new test dependencies.
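For context on the K-S check mentioned above, a minimal sketch of a Kolmogorov-Smirnov goodness-of-fit test for `exponential_` (sample size and significance level are illustrative assumptions, not the PR's actual test code):

```python
import torch
from scipy import stats

# Draw from the sampler under test (illustrative sample size).
x = torch.empty(100_000, dtype=torch.float64).exponential_(lambd=1.0)
# Compare the empirical distribution against the Exp(1) CDF.
statistic, p_value = stats.kstest(x.numpy(), "expon")
assert p_value > 0.05  # fail to reject => consistent with Exp(1)
```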
1. Fix amax, argmax, and triu: use int64 indexing when the largest tensor's size in bytes exceeds int32's max (see the sketch after this list).
2. Change the tiling scheme for argmax to loop over the reduction dimension instead of using a data-size-dependent tile size.
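A minimal sketch of the int64-indexing guard described in point 1 (the helper name and threshold constant are hypothetical, not the PR's code):

```python
import torch

INT32_MAX = 2**31 - 1

def needs_int64_indexing(t: torch.Tensor) -> bool:
    # Hypothetical helper: fall back to int64 offsets once the tensor's
    # byte size no longer fits in a signed 32-bit index.
    return t.numel() * t.element_size() > INT32_MAX
```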
- libentry is now lock-protected.
- Add multithreading tests for libentry.
- Polish code.
[Test] Test for op
…onflicts with master.
I don't understand why the chi-square test proves accuracy here.
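For context, a chi-square goodness-of-fit test compares observed category counts against the counts expected under the target distribution. A minimal sketch (illustrative probabilities, sample size, and threshold; not the PR's test code):

```python
import torch
from scipy import stats

probs = torch.tensor([0.1, 0.2, 0.3, 0.4])
samples = torch.multinomial(probs, 100_000, replacement=True)
# Observed counts per category vs. counts expected under `probs`.
observed = torch.bincount(samples, minlength=probs.numel()).double()
expected = (probs / probs.sum()).double() * samples.numel()
_, p_value = stats.chisquare(observed.numpy(), expected.numpy())
assert p_value > 0.05  # fail to reject => frequencies match `probs`
```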
Also added fused_renorm_cumsum for better perf.
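For reference, an eager-mode equivalent of what a fused renorm + cumsum computes (a sketch under the assumption that the fusion combines these two steps; not the kernel itself):

```python
import torch

def renorm_cumsum_reference(inp: torch.Tensor, dim: int = -1) -> torch.Tensor:
    # Renormalize weights along `dim` into probabilities, then build
    # the inclusive cumulative sum (i.e., the CDF used for sampling).
    return torch.cumsum(inp / inp.sum(dim=dim, keepdim=True), dim=dim)
```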
finished
src/flag_gems/ops/cumsum.py
Outdated

```python
def fused_renorm_cumsum(inp, dim=-1):
    logging.debug("GEMS RENORM_CUMSUM")
    assert inp.dtype in (torch.float16, torch.float32, torch.float64)
```
allow inp.dtype to be torch.bfloat16
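Presumably the assertion would become something like:

```python
# Suggested change (sketch): include bfloat16 among the allowed dtypes.
assert inp.dtype in (torch.float16, torch.bfloat16, torch.float32, torch.float64)
```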
Sure.
lgtm
This multinomial is a Triton port of its PyTorch counterpart.
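For readers unfamiliar with the approach, a pure-PyTorch sketch of the inverse-CDF sampling such a kernel typically implements (illustrative only; the actual Triton kernel's details may differ):

```python
import torch

def multinomial_sketch(probs: torch.Tensor, num_samples: int) -> torch.Tensor:
    # Build the CDF from renormalized weights (the renorm + cumsum step),
    # then binary-search uniform draws against it.
    cdf = torch.cumsum(probs / probs.sum(dim=-1, keepdim=True), dim=-1)
    u = torch.rand(num_samples, device=probs.device)
    # Clamp guards against a draw landing past a CDF that tops out
    # slightly below 1.0 due to floating-point rounding.
    return torch.searchsorted(cdf, u).clamp_(max=probs.numel() - 1)
```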
Performance figure updated.