Improve reduce performance by passing CartesianIndices and length statically #100

Merged · 2 commits · Feb 22, 2023
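
In brief: `partial_mapreduce_device` used to receive the iterators `Rreduce` and `Rother` as ordinary run-time arguments and call `length(Rother)` inside the kernel; this PR lifts them (plus the length, as `Rlen`) into `Val` type parameters, making them compile-time constants. A minimal sketch of the specialization trick, in plain Julia with hypothetical function names (no GPU required):

# Before: the iterator is an ordinary argument, so length(Rother) is a run-time value.
decompose_dynamic(i, Rother) = fldmod1(i, length(Rother))

# After: iterator and length live in the type signature, so `Rlen` is a literal
# constant in the generated code (stand-ins for the kernel's Val arguments).
decompose_static(i, ::Val{Rother}, ::Val{Rlen}) where {Rother, Rlen} = fldmod1(i, Rlen)

Rother = CartesianIndices((1, 4))   # singleton entries along the reduced dimensions
@assert decompose_dynamic(7, Rother) ==
        decompose_static(7, Val(Rother), Val(UInt64(length(Rother))))

This works because `CartesianIndices` is an isbits type and can therefore appear as a type parameter; with the divisor a known constant, the `fldmod1` and `÷` in the kernel's hot path can be strength-reduced at compile time.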
15 changes: 8 additions & 7 deletions src/mapreduce.jl
@@ -90,13 +90,14 @@ Base.@propagate_inbounds _map_getindex(args::Tuple{}, I) = ()
 # Reduce an array across the grid. All elements to be processed can be addressed by the
 # product of the two iterators `Rreduce` and `Rother`, where the latter iterator will have
 # singleton entries for the dimensions that should be reduced (and vice versa).
-function partial_mapreduce_device(f, op, neutral, maxthreads, Rreduce, Rother, shuffle, R, As...)
+function partial_mapreduce_device(f, op, neutral, maxthreads, ::Val{Rreduce},
+                                  ::Val{Rother}, ::Val{Rlen}, shuffle, R, As...) where {Rreduce, Rother, Rlen}
     # decompose the 1D hardware indices into separate ones for reduction (across items
     # and possibly groups if it doesn't fit) and other elements (remaining groups)
     localIdx_reduce = thread_position_in_threadgroup_1d()
     localDim_reduce = threads_per_threadgroup_1d()
-    groupIdx_reduce, groupIdx_other = fldmod1(threadgroup_position_in_grid_1d(), length(Rother))
-    groupDim_reduce = threadgroups_per_grid_1d() ÷ length(Rother)
+    groupIdx_reduce, groupIdx_other = fldmod1(threadgroup_position_in_grid_1d(), Rlen)
+    groupDim_reduce = threadgroups_per_grid_1d() ÷ Rlen

     # group-based indexing into the values outside of the reduction dimension
     # (that means we can safely synchronize items within this group)
@@ -171,8 +172,8 @@ function GPUArrays.mapreducedim!(f::F, op::OP, R::WrappedMtlArray{T},
     # that's why each thread also loops across its inputs, processing multiple values
     # so that we can span the entire reduction dimension using a single item group.
     # XXX: can we query the 1024?
-    kernel = @metal launch=false partial_mapreduce_device(f, op, init, Val(1024), Rreduce,
-                                                          Rother, Val(shuffle), R′, A)
+    kernel = @metal launch=false partial_mapreduce_device(f, op, init, Val(1024), Val(Rreduce), Val(Rother),
+                                                          Val(UInt64(length(Rother))), Val(shuffle), R′, A)
     pipeline = MtlComputePipelineState(kernel.fun.lib.device, kernel.fun)

     # how many threads do we want?
@@ -206,7 +207,7 @@ function GPUArrays.mapreducedim!(f::F, op::OP, R::WrappedMtlArray{T},
     if reduce_groups == 1
         # we can cover the dimensions to reduce using a single group
         @metal threads=threads grid=groups partial_mapreduce_device(
-            f, op, init, Val(threads), Rreduce, Rother, Val(shuffle), R′, A)
+            f, op, init, Val(threads), Val(Rreduce), Val(Rother), Val(UInt64(length(Rother))), Val(shuffle), R′, A)
     else
         # we need multiple steps to cover all values to reduce
         partial = similar(R, (size(R)..., reduce_groups))
@@ -220,7 +221,7 @@ function GPUArrays.mapreducedim!(f::F, op::OP, R::WrappedMtlArray{T},
             end
         end
         @metal threads=threads grid=groups partial_mapreduce_device(
-            f, op, init, Val(threads), Rreduce, Rother, Val(shuffle), partial, A)
+            f, op, init, Val(threads), Val(Rreduce), Val(Rother), Val(UInt64(length(Rother))), Val(shuffle), partial, A)

         GPUArrays.mapreducedim!(identity, op, R′, partial; init=init)
     end
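
For reference, a CPU-side sketch of the index decomposition the kernel performs (hypothetical sizes; `group` stands in for `threadgroup_position_in_grid_1d()`), showing how `fldmod1` splits each linear threadgroup position into a reduction slot and an index into `Rother`:

Rother = CartesianIndices((2, 1))   # kept dimensions; singletons where we reduce
Rlen   = UInt64(length(Rother))     # the value this PR now passes via Val
reduce_groups = 3                   # threadgroups cooperating on each reduction

for group in 1:(reduce_groups * Rlen)
    # the same split the kernel performs on its linear threadgroup position
    groupIdx_reduce, groupIdx_other = fldmod1(group, Rlen)
    println("group $group -> reduce slot $groupIdx_reduce, output index $(Rother[groupIdx_other])")
end

Each of the `Rlen` output locations is visited by `reduce_groups` threadgroups, matching the kernel's `groupDim_reduce = threadgroups_per_grid_1d() ÷ Rlen`.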