
Multinomial #141

Merged: 56 commits into master from the multinomial branch on Sep 2, 2024
Conversation

@tongxin (Contributor) commented on Jul 30, 2024:

This PR adds a multinomial operator as a Triton conversion of the PyTorch counterpart.

Performance figures are updated below.

benchmark/test_special_perf.py
Operator multinomial Performance Test (torch.float16)
Size        Torch Latency (ms)   Gems Latency (ms)
--------------------------------------------------
1024                  0.411968            0.244064
6144                  0.690048            0.449952
11264                 0.931296             0.53248
16384                  1.13818            0.664704
21504                  1.23549            0.677472
26624                  1.39875            0.745216
31744                  1.54067             0.80096
36864                  1.74266            0.934592
41984                   1.8904            0.965376
47104                  2.07027             1.03923
52224                  2.23818             1.08048
57344                  2.45094             1.19658
62464                  2.59354             1.19482
67584                  2.71466             1.29584
72704                  2.92397              1.3409
77824                  3.11814             1.42378
Operator multinomial Performance Test (torch.float32)
Size        Torch Latency (ms)   Gems Latency (ms)
--------------------------------------------------
1024                  0.451392            0.267328
6144                   0.80832            0.466176
11264                  1.08074            0.544896
16384                  1.33142            0.743872
21504                  1.45101            0.778528
26624                  1.65642            0.899328
31744                  1.85642            0.982752
36864                  2.11466             1.14086
41984                  2.31005             1.20314
47104                  2.52592             1.30464
52224                  2.71603             1.37322
57344                  2.94445             1.50336
62464                  3.13027             1.54714
67584                  3.29763             1.67389
72704                  3.50435             1.76499
77824                  3.69504             1.85344
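For context, here is a minimal usage sketch. It assumes the flag_gems.enable() entry point documented in the project README routes supported aten ops (including torch.multinomial after this PR) to the Triton kernels; the sizes and parameters are illustrative, not those used by benchmark/test_special_perf.py.

import torch
import flag_gems

# Assumption: enable() globally patches supported torch ops to the Gems Triton kernels.
flag_gems.enable()

# Unnormalized category weights; multinomial normalizes them internally.
probs = torch.rand(1024, device="cuda", dtype=torch.float16)
samples = torch.multinomial(probs, num_samples=1024, replacement=True)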

@tongxin marked this pull request as draft on July 30, 2024, and as ready for review on August 1, 2024.
Bowen12992 and others added 20 commits August 1, 2024 14:30
* exponential added.
* Added K-S tests to exponential_, fp64 corrected.
* Aligned with aten prototype.
* Exponential_ uses uint64 offsets in the Triton kernel.
* Update pyproject config for new test dependencies.
* Fix amax, argmax and triu: use int64 indexing when the largest tensor's size_in_bytes exceeds int32's max.
* Change the tiling scheme for argmax to loop over the reduction dimension instead of using a data-size-dependent tile size.
* libentry is now lock protected.
* Add multithreading tests for libentry.
* Polish code.
@StrongSpoon (Collaborator) left a comment:
I don't understand why chi-square proves the accuracy.

Review threads on tests/test_special_ops.py, benchmark/test_special_perf.py, src/flag_gems/ops/multinomial.py, and tests/test_distribution_ops.py were resolved.
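On the chi-square question above: the usual idea is to tally how often each category is drawn and compare the observed counts against the counts expected from the normalized input weights; a small p-value would flag a mismatch. Below is a rough sketch of such a check using scipy.stats.chisquare. It is illustrative only, with a hypothetical helper name, and not necessarily the test that ended up in tests/test_special_ops.py.

import torch
from scipy import stats

def check_multinomial_chisquare(weights, n_samples=200_000, alpha=0.01):
    # Hypothetical helper: sample with replacement and tally category counts.
    idx = torch.multinomial(weights, n_samples, replacement=True)
    observed = torch.bincount(idx, minlength=weights.numel()).double()
    # Expected counts under the normalized input distribution.
    expected = weights.double() / weights.sum() * n_samples
    # A large p-value means no evidence that the sampler deviates from the target.
    _, p_value = stats.chisquare(observed.cpu().numpy(), expected.cpu().numpy())
    return p_value > alpha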
@StrongSpoon self-assigned this on Aug 19, 2024.
@tongxin (Contributor, Author) commented on Aug 19, 2024:

Also added fused_norm_cumsum for better perf.
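As a rough picture of what the fused op computes (an assumption based on the name: renormalize along the sampling dimension, then take the cumulative sum to form the CDF that an inverse-CDF lookup consumes), here is an unfused PyTorch sketch with hypothetical helper names. The PR fuses the two steps into a single Triton kernel for the perf gain mentioned above.

import torch

def renorm_cumsum_reference(inp, dim=-1):
    # Unfused reference: renormalize along `dim`, then inclusive cumulative sum (the CDF).
    normed = inp / inp.sum(dim=dim, keepdim=True)
    return normed.cumsum(dim=dim)

def sample_reference(weights, n_samples, generator=None):
    # Inverse-CDF multinomial sampling with replacement, for orientation only.
    cdf = renorm_cumsum_reference(weights.float())
    u = torch.rand(*weights.shape[:-1], n_samples, generator=generator, device=weights.device)
    return torch.searchsorted(cdf, u).clamp_(max=weights.size(-1) - 1)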

@StrongSpoon (Collaborator) left a comment:
finished

Review threads on benchmark/test_special_perf.py and tests/test_special_ops.py were resolved.

def fused_renorm_cumsum(inp, dim=-1):
    logging.debug("GEMS RENORM_CUMSUM")
    assert inp.dtype in (torch.float16, torch.float32, torch.float64)
Collaborator comment:

allow inp.dtype to be torch.bfloat16

Author reply:

Sure.
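For concreteness, the requested change would presumably just widen the dtype check, for example:

# allow bfloat16 alongside the existing float dtypes (per the review comment above)
assert inp.dtype in (torch.float16, torch.bfloat16, torch.float32, torch.float64)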

@StrongSpoon (Collaborator) left a comment:
lgtm

@tongxin merged commit 2f191fe into master on Sep 2, 2024 (4 checks passed) and deleted the multinomial branch on September 2, 2024.