Skip to content

Commit

Permalink
Improve reduce performance by passing CartesianIndices and length statically (#100)
Browse files Browse the repository at this point in the history

Co-authored-by: Tim Besard <[email protected]>
  • Loading branch information
maxwindiff and maleadt authored Feb 22, 2023
1 parent 9cafba7 commit 25a7930
Showing 1 changed file with 8 additions and 7 deletions.
15 changes: 8 additions & 7 deletions src/mapreduce.jl
Original file line number Diff line number Diff line change
Expand Up @@ -90,13 +90,14 @@ Base.@propagate_inbounds _map_getindex(args::Tuple{}, I) = ()
# Reduce an array across the grid. All elements to be processed can be addressed by the
# product of the two iterators `Rreduce` and `Rother`, where the latter iterator will have
# singleton entries for the dimensions that should be reduced (and vice versa).
function partial_mapreduce_device(f, op, neutral, maxthreads, Rreduce, Rother, shuffle, R, As...)
function partial_mapreduce_device(f, op, neutral, maxthreads, ::Val{Rreduce},
::Val{Rother}, ::Val{Rlen}, shuffle, R, As...) where {Rreduce, Rother, Rlen}
# decompose the 1D hardware indices into separate ones for reduction (across items
# and possibly groups if it doesn't fit) and other elements (remaining groups)
localIdx_reduce = thread_position_in_threadgroup_1d()
localDim_reduce = threads_per_threadgroup_1d()
groupIdx_reduce, groupIdx_other = fldmod1(threadgroup_position_in_grid_1d(), length(Rother))
groupDim_reduce = threadgroups_per_grid_1d() ÷ length(Rother)
groupIdx_reduce, groupIdx_other = fldmod1(threadgroup_position_in_grid_1d(), Rlen)
groupDim_reduce = threadgroups_per_grid_1d() ÷ Rlen

# group-based indexing into the values outside of the reduction dimension
# (that means we can safely synchronize items within this group)
Expand Down Expand Up @@ -171,8 +172,8 @@ function GPUArrays.mapreducedim!(f::F, op::OP, R::WrappedMtlArray{T},
# that's why each threads also loops across their inputs, processing multiple values
# so that we can span the entire reduction dimension using a single item group.
# XXX: can we query the 1024?
kernel = @metal launch=false partial_mapreduce_device(f, op, init, Val(1024), Rreduce,
Rother, Val(shuffle), R′, A)
kernel = @metal launch=false partial_mapreduce_device(f, op, init, Val(1024), Val(Rreduce), Val(Rother),
Val(UInt64(length(Rother))), Val(shuffle), R′, A)
pipeline = MtlComputePipelineState(kernel.fun.lib.device, kernel.fun)

# how many threads do we want?
Expand Down Expand Up @@ -206,7 +207,7 @@ function GPUArrays.mapreducedim!(f::F, op::OP, R::WrappedMtlArray{T},
if reduce_groups == 1
# we can cover the dimensions to reduce using a single group
@metal threads=threads grid=groups partial_mapreduce_device(
f, op, init, Val(threads), Rreduce, Rother, Val(shuffle), R′, A)
f, op, init, Val(threads), Val(Rreduce), Val(Rother), Val(UInt64(length(Rother))), Val(shuffle), R′, A)
else
# we need multiple steps to cover all values to reduce
partial = similar(R, (size(R)..., reduce_groups))
Expand All @@ -220,7 +221,7 @@ function GPUArrays.mapreducedim!(f::F, op::OP, R::WrappedMtlArray{T},
end
end
@metal threads=threads grid=groups partial_mapreduce_device(
f, op, init, Val(threads), Rreduce, Rother, Val(shuffle), partial, A)
f, op, init, Val(threads), Val(Rreduce), Val(Rother), Val(UInt64(length(Rother))), Val(shuffle), partial, A)

GPUArrays.mapreducedim!(identity, op, R′, partial; init=init)
end
Expand Down

0 comments on commit 25a7930

Please sign in to comment.