Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement Median #3512

Open
wants to merge 17 commits into
base: develop
Choose a base branch
from
Open

Implement Median #3512

wants to merge 17 commits into from

Conversation

anhskrttt
Copy link
Collaborator

@anhskrttt anhskrttt commented Feb 14, 2025

  • Add Median operation with forward and backward kernel.
    • This op basically reuses kthvalue operation.
  • Add driver and gtest for kernel.
  • Performance condition: MIOpen performs better if the following are applied
    • input_num_dims > 1
    • non contiguous input
    • selected dim size (i.e. reduce_size) > 250 (for bwd) or > 300 (for fwd)
    • selected dim stride = 1

Average improvement over ROCm

type fwd bwd
float 3.08 3.69
float16 2.88 3.74
bfloat16 3.28 3.74

Detail Benchmark

fp32
dtype input_size contiguous dim direction ROCm MIOpen Improvement
float32 [16384 2 8 10 10] FALSE 0 bwd 2070063 104622 19.79
float32 [1000 32768] FALSE 0 fwd 21262222 2346920 9.06
float32 [16384 25 5 2] FALSE 0 bwd 128479 20675 6.21
float32 [1000 16384] FALSE 0 fwd 6543625 1183350 5.53
float32 [32768 25 5 2] FALSE 0 fwd 3733089 708617 5.27
float32 [1000 2048] FALSE 0 fwd 606875 171945 3.53
float32 [2048 25 5 2] FALSE 0 bwd 26720 7875 3.39
float32 [1000 1024] FALSE 0 bwd 46560 13813 3.37
float32 [4096 25 5 2] FALSE 0 fwd 314878 114755 2.74
float32 [500 512] FALSE 0 bwd 21280 9048 2.35
float32 [512 25 5 2] FALSE 0 bwd 16480 7324 2.25
float32 [500 2048] FALSE 0 fwd 293437 137617 2.13
float32 [500 512] FALSE 0 fwd 90559 50275 1.80
float32 [8192 60] FALSE 0 fwd 188319 113262 1.66
float32 [512 20] FALSE 0 fwd 30720 33813 0.91
fp16
dtype input_size contiguous dim direction ROCm MIOpen Improvement
float16 [16384 2 8 10 10] FALSE 0 bwd 1090232 56408 19.33
float16 [32768 25 5 2] FALSE 0 bwd 222719 23946 9.30
float16 [1000 32768] FALSE 0 fwd 11169348 1880860 5.94
float16 [16384 70] FALSE 0 bwd 40640 9262 4.39
float16 [1000 2048] FALSE 0 fwd 512956 136888 3.75
float16 [2048 25 5 2] FALSE 0 bwd 26080 7857 3.32
float16 [1000 512] FALSE 0 bwd 30720 9671 3.18
float16 [1000 4096] FALSE 0 bwd 133438 45866 2.91
float16 [1000 8192] FALSE 0 bwd 243998 87110 2.80
float16 [500 512] FALSE 0 bwd 22880 8711 2.63
float16 [500 32768] FALSE 0 fwd 3619330 1509480 2.40
float16 [500 4096] FALSE 0 fwd 481437 207251 2.32
float16 [1024 25 5 2] FALSE 0 fwd 87519 37262 2.35
float16 [500 1024] FALSE 0 bwd 33440 13528 2.47
float16 [4096 50] FALSE 0 fwd 78720 58559 1.34
bfp16
dtype input_size contiguous dim direction ROCm MIOpen Improvement
bfloat16 [16384 2 8 10 10] FALSE 0 bwd 1094231 56693 19.30
bfloat16 [32768 25 5 2] FALSE 0 bwd 223839 23608 9.48
bfloat16 [16384 25 5 2] FALSE 0 bwd 115358 13439 8.58
bfloat16 [1000 32768] FALSE 0 fwd 12844534 1940770 6.62
bfloat16 [32768 80] FALSE 0 bwd 78399 12266 6.39
bfloat16 [16384 2 8 10 10] FALSE 0 fwd 6902023 1102280 6.26
bfloat16 [32768 25 5 2] FALSE 0 fwd 2420780 437313 5.54
bfloat16 [16384 25 5 2] FALSE 0 fwd 1157751 223625 5.18
bfloat16 [1000 8192] FALSE 0 fwd 2177902 498147 4.37
bfloat16 [1000 4096] FALSE 0 fwd 1131671 259340 4.36
bfloat16 [16384 70] FALSE 0 bwd 40799 9351 4.36
bfloat16 [1000 16384] FALSE 0 fwd 4217725 979104 4.31
bfloat16 [1000 2048] FALSE 0 fwd 586235 138328 4.24
bfloat16 [1000 32768] FALSE 0 bwd 1340949 336407 3.99
bfloat16 [1000 1024] FALSE 0 fwd 308957 80141 3.86

Comment on lines +51 to +60
bool IsImprovementOverROCm(const miopen::median::BwdProblemDescription& problem)
{
auto dim = problem.GetDim();
auto input_grad_lengths = problem.GetInputGradDesc().GetLengths();
auto dim_size = input_grad_lengths[dim];
auto dim_stride = problem.GetInputGradDesc().GetStrides()[dim];

return input_grad_lengths.size() > 1 && !problem.IsAllContiguous() && dim_size > 250 &&
dim_stride == 1;
}
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

move to unnamed namespace

Comment on lines +51 to +60
bool IsImprovementOverROCm(const miopen::median::FwdProblemDescription& problem)
{
auto dim = problem.GetDim();
auto input_lengths = problem.GetInputDesc().GetLengths();
auto dim_size = input_lengths[dim];
auto dim_stride = problem.GetInputDesc().GetStrides()[dim];

return input_lengths.size() > 1 && !problem.IsAllContiguous() && dim_size > 300 &&
dim_stride == 1;
}
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

move to unnamed namespace

@anhskrttt anhskrttt marked this pull request as draft February 16, 2025 15:19
@anhskrttt anhskrttt marked this pull request as ready for review February 18, 2025 03:56
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

1 participant