Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve reduce performance by passing CartesianIndices and length statically #100

Merged
merged 2 commits into from
Feb 22, 2023

Conversation

maxwindiff
Copy link
Contributor

@maxwindiff maxwindiff commented Feb 21, 2023

Improve indexing performance by passing CartesianIndices statically, using a similar trick as JuliaGPU/GPUArrays.jl#454. Still slow, but not as bad as before. Helps with #46.

Before:

julia> a = fill(Float32(1.0), 4096 * 4096);
julia> da = MtlArray(a);
julia> b = fill(Float32(1.0), 4096, 4096);
julia> db = MtlArray(b);

julia> @btime sum(a)
  1.393 ms (1 allocation: 16 bytes)
1.6777216f7

julia> @btime sum(b)
  1.392 ms (1 allocation: 16 bytes)
1.6777216f7

julia> @btime sum(da)
  4.026 ms (868 allocations: 23.95 KiB)
1.6777216f7

julia> @btime sum(db)
  11.196 ms (873 allocations: 25.23 KiB)
1.6777216f7

After:

julia> @btime sum(da)
  1.811 ms (754 allocations: 20.80 KiB)
1.6777216f7

julia> @btime sum(db)
  2.181 ms (759 allocations: 21.33 KiB)
1.6777216f7

Passing length(Rother) as Rlen may look redundant, but the 2D case (sum(db)) runs 3x slower without it.

julia> @btime sum(db)
  6.648 ms (759 allocations: 21.33 KiB)
1.6777216f7

There were some test failures, but they also happen on main (complains about symbol not found) and seems unrelated to this PR -- https://gist.github.com/maxwindiff/fe0480dcfd1bcd4cb28e91f2c1a0cfa6

src/mapreduce.jl Outdated Show resolved Hide resolved
@maleadt
Copy link
Member

maleadt commented Feb 21, 2023

LGTM, for now at least. This isn't something we want to apply everywhere due to the increased compile times, it's better to figure out a way to encode dynamic Cartesian indices in a way that Metal can handle them somewhat performantly.

@maleadt
Copy link
Member

maleadt commented Feb 21, 2023

Did you explore adding back some of the information that gets lost by @inbounds? That improved performance significantly in JuliaGPU/GPUArrays.jl#454.

Co-authored-by: Tim Besard <[email protected]>
@maxwindiff
Copy link
Contributor Author

The linear indexing at https://github.com/JuliaGPU/Metal.jl/blob/main/src/mapreduce.jl#L105 and https://github.com/JuliaGPU/Metal.jl/blob/main/src/mapreduce.jl#L120 were guarded by range checks already. What other bounds info should I try?

I tried this but there's no improvement:

diff --git a/src/mapreduce.jl b/src/mapreduce.jl
index e878851..29010ae 100644
--- a/src/mapreduce.jl
+++ b/src/mapreduce.jl
@@ -96,6 +96,7 @@ function partial_mapreduce_device(f, op, neutral, maxthreads, ::Val{Rreduce},
     # and possibly groups if it doesn't fit) and other elements (remaining groups)
     localIdx_reduce = thread_position_in_threadgroup_1d()
     localDim_reduce = threads_per_threadgroup_1d()
+    assume(1 <= Rlen)
     groupIdx_reduce, groupIdx_other = fldmod1(threadgroup_position_in_grid_1d(), Rlen)
     groupDim_reduce = threadgroups_per_grid_1d() ÷ Rlen
 
@@ -103,6 +104,7 @@ function partial_mapreduce_device(f, op, neutral, maxthreads, ::Val{Rreduce},
     # (that means we can safely synchronize items within this group)
     iother = groupIdx_other
     @inbounds if iother <= length(Rother)
+        assume(1 <= iother <= length(Rother))
         Iother = Rother[iother]
 
         # load the neutral value
@@ -118,6 +120,7 @@ function partial_mapreduce_device(f, op, neutral, maxthreads, ::Val{Rreduce},
         # reduce serially across chunks of input vector that don't fit in a group
         ireduce = localIdx_reduce + (groupIdx_reduce - 1) * localDim_reduce
         while ireduce <= length(Rreduce)
+            assume(1 <= ireduce <= length(Rreduce))
             Ireduce = Rreduce[ireduce]
             J = max(Iother, Ireduce)
             val = op(val, f(_map_getindex(As, J)...))

@maleadt
Copy link
Member

maleadt commented Feb 22, 2023

Yeah I guess that covers all of them already.

@maleadt maleadt merged commit 25a7930 into JuliaGPU:main Feb 22, 2023
@maxwindiff maxwindiff deleted the reduce branch February 26, 2023 07:09
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
performance Gotta go fast.
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants