
Commit

Reduce multiple consecutive values in each thread to improve efficiency
maxwindiff authored Mar 9, 2023
1 parent c19b940 commit af6f7c4
Showing 1 changed file with 31 additions and 14 deletions.
src/mapreduce.jl (45 changes: 31 additions & 14 deletions)
@@ -91,13 +91,12 @@ Base.@propagate_inbounds _map_getindex(args::Tuple{}, I) = ()
 # product of the two iterators `Rreduce` and `Rother`, where the latter iterator will have
 # singleton entries for the dimensions that should be reduced (and vice versa).
 function partial_mapreduce_device(f, op, neutral, maxthreads, ::Val{Rreduce},
-                                  ::Val{Rother}, ::Val{Rlen}, shuffle, R, As...) where {Rreduce, Rother, Rlen}
+                                  ::Val{Rother}, ::Val{Rlen}, ::Val{grain}, shuffle, R, As...) where {Rreduce, Rother, Rlen, grain}
     # decompose the 1D hardware indices into separate ones for reduction (across items
     # and possibly groups if it doesn't fit) and other elements (remaining groups)
     localIdx_reduce = thread_position_in_threadgroup_1d()
-    localDim_reduce = threads_per_threadgroup_1d()
+    localDim_reduce = threads_per_threadgroup_1d() * grain
     groupIdx_reduce, groupIdx_other = fldmod1(threadgroup_position_in_grid_1d(), Rlen)
-    groupDim_reduce = threadgroups_per_grid_1d() ÷ Rlen

     # group-based indexing into the values outside of the reduction dimension
     # (that means we can safely synchronize items within this group)
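
To make the index arithmetic concrete, here is a small CPU-side sketch (plain Julia, no Metal required) of how `fldmod1` splits the flat threadgroup index into a reduction coordinate and an "other" coordinate; the group count of 6 and `Rlen = 3` are made-up illustration values:

# Split a flat 1-D grid of threadgroups into a "reduce" coordinate and an
# "other" coordinate, as the kernel does. Illustration values: Rlen = 3
# slices to keep, 6 threadgroups in total (2 groups per slice).
Rlen = 3
for group in 1:6
    groupIdx_reduce, groupIdx_other = fldmod1(group, Rlen)
    println("group $group -> reduce $groupIdx_reduce, other $groupIdx_other")
end
# group 1 -> reduce 1, other 1
# group 2 -> reduce 1, other 2
# group 3 -> reduce 1, other 3
# group 4 -> reduce 2, other 1
# group 5 -> reduce 2, other 2
# group 6 -> reduce 2, other 3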
@@ -115,13 +114,19 @@ function partial_mapreduce_device(f, op, neutral, maxthreads, ::Val{Rreduce},

     val = op(neutral, neutral)

-    # reduce serially across chunks of input vector that don't fit in a group
-    ireduce = localIdx_reduce + (groupIdx_reduce - 1) * localDim_reduce
-    while ireduce <= length(Rreduce)
-        Ireduce = Rreduce[ireduce]
-        J = max(Iother, Ireduce)
-        val = op(val, f(_map_getindex(As, J)...))
-        ireduce += localDim_reduce * groupDim_reduce
+    # read multiple consecutive values in the reduction dimension to improve efficiency
+    ireduce = (localIdx_reduce - 1) * grain + (groupIdx_reduce - 1) * localDim_reduce
+    limit = ireduce + grain
+    while ireduce < limit
+        ireduce += 1
+        next = if ireduce <= length(Rreduce)
+            Ireduce = Rreduce[ireduce]
+            J = max(Iother, Ireduce)
+            f(_map_getindex(As, J)...)
+        else
+            neutral
+        end
+        val = op(val, next)
     end

     val = reduce_group(op, val, neutral, shuffle, maxthreads)
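
A minimal CPU sketch of the new per-thread loop, assuming a 1-D input; `thread_partial` is a hypothetical helper (not part of this commit) that mirrors the indexing and the `neutral` padding above:

# Hypothetical helper mirroring the kernel's grain loop: thread `t` (1-based)
# of a group whose slice starts at `base` reduces `grain` consecutive
# elements, substituting `neutral` past the end of the input so every thread
# performs the same number of operations.
function thread_partial(op, f, xs, neutral, t, base, grain)
    val = op(neutral, neutral)
    i = base + (t - 1) * grain
    for k in 1:grain
        j = i + k
        val = op(val, j <= length(xs) ? f(xs[j]) : neutral)
    end
    return val
end

# Thread 3 with grain 4 covers elements 9:12; indices 11 and 12 fall past the
# end of a length-10 input and contribute the neutral element:
thread_partial(+, identity, collect(1:10), 0, 3, 0, 4)  # 9 + 10 + 0 + 0 = 19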
@@ -166,14 +171,21 @@ function GPUArrays.mapreducedim!(f::F, op::OP, R::WrappedMtlArray{T},
     # but allows us to write a generalized kernel supporting partial reductions.
     R′ = reshape(R, (size(R)..., 1))

+    # when the reduction dimension is contiguous in memory, we can improve performance
+    # by having each thread read multiple consecutive elements. based on experiments,
+    # 16 / sizeof(T) elements is usually a good choice.
+    reduce_dim_start = something(findfirst(axis -> length(axis) > 1, axes(Rreduce)), 1)
+    contiguous = prod(size(R)[1:reduce_dim_start-1]) == 1
+    grain = contiguous ? prevpow(2, cld(16, sizeof(T))) : 1
+
     # how many threads can we launch?
     #
     # we might not be able to launch all those threads to reduce each slice in one go.
     # that's why each thread also loops across its inputs, processing multiple values
     # so that we can span the entire reduction dimension using a single item group.
     # XXX: can we query the 1024?
     kernel = @metal launch=false partial_mapreduce_device(f, op, init, Val(1024), Val(Rreduce), Val(Rother),
-                                                          Val(UInt64(length(Rother))), Val(shuffle), R′, A)
+                                                          Val(UInt64(length(Rother))), Val(grain), Val(shuffle), R′, A)
     pipeline = MTLComputePipelineState(kernel.fun.device, kernel.fun)

     # how many threads do we want?
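
For reference, the 16-byte heuristic evaluates as follows for common element types, and falls back to `grain = 1` whenever a leading non-reduced dimension makes the reads strided (plain REPL computations, not part of the diff):

# 16-byte grain heuristic for a few element types:
prevpow(2, cld(16, sizeof(Float16)))  # 8
prevpow(2, cld(16, sizeof(Float32)))  # 4
prevpow(2, cld(16, sizeof(Float64)))  # 2
prevpow(2, cld(16, sizeof(Int8)))     # 16

# Contiguity check on a 1024×16 Float32 matrix: reducing dims=1 gives
# reduce_dim_start == 1 and prod(size(R)[1:0]) == 1, so grain == 4; reducing
# dims=2 gives reduce_dim_start == 2 with size(R, 1) == 1024, so consecutive
# elements along the reduction dimension are far apart in memory and grain
# falls back to 1.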
@@ -196,8 +208,11 @@ function GPUArrays.mapreducedim!(f::F, op::OP, R::WrappedMtlArray{T},
     # even though we can always reduce each slice in a single item group, that may not be
     # optimal as it might not saturate the GPU. we already launch some groups to process
     # independent dimensions in parallel; pad that number to ensure full occupancy.
+    #
+    # also, make sure the grain size is not so high that it starves threads of work.
     other_groups = length(Rother)
-    reduce_groups = cld(length(Rreduce), reduce_threads)
+    grain = min(grain, prevpow(2, cld(length(Rreduce), reduce_threads)))
+    reduce_groups = cld(length(Rreduce), reduce_threads * grain)

     # determine the launch configuration
     threads = reduce_threads
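
A worked example of the clamp, using made-up sizes:

# Reducing a dimension of length 512 with 256 threads per group leaves only
# 2 elements per thread; a grain of 4 would idle half the threads, so the
# clamp lowers it to the largest power of 2 that still gives full coverage:
len, reduce_threads, grain = 512, 256, 4
grain = min(grain, prevpow(2, cld(len, reduce_threads)))   # min(4, 2) == 2
reduce_groups = cld(len, reduce_threads * grain)           # cld(512, 512) == 1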
@@ -207,7 +222,8 @@ function GPUArrays.mapreducedim!(f::F, op::OP, R::WrappedMtlArray{T},
     if reduce_groups == 1
         # we can cover the dimensions to reduce using a single group
         @metal threads=threads grid=groups partial_mapreduce_device(
-            f, op, init, Val(threads), Val(Rreduce), Val(Rother), Val(UInt64(length(Rother))), Val(shuffle), R′, A)
+            f, op, init, Val(threads), Val(Rreduce), Val(Rother),
+            Val(UInt64(length(Rother))), Val(grain), Val(shuffle), R′, A)
     else
         # we need multiple steps to cover all values to reduce
         partial = similar(R, (size(R)..., reduce_groups))
@@ -221,7 +237,8 @@ function GPUArrays.mapreducedim!(f::F, op::OP, R::WrappedMtlArray{T},
         end
     end
     @metal threads=threads grid=groups partial_mapreduce_device(
-        f, op, init, Val(threads), Val(Rreduce), Val(Rother), Val(UInt64(length(Rother))), Val(shuffle), partial, A)
+        f, op, init, Val(threads), Val(Rreduce), Val(Rother),
+        Val(UInt64(length(Rother))), Val(grain), Val(shuffle), partial, A)

     GPUArrays.mapreducedim!(identity, op, R′, partial; init=init)
 end
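
Shape-wise, the two-step path stores one partial result per reduce group in an extra trailing dimension and then folds that dimension into `R′`. A CPU sketch with made-up sizes, specialized to `op = +`:

# With R′ of size (1, 16, 1) and 4 reduce groups, `partial` has size (1, 16, 4);
# the recursive mapreducedim! call collapses the trailing dimension:
partial = rand(1, 16, 4)
R′ = sum(partial; dims=3)  # what mapreducedim!(identity, op, R′, partial) computes for op = +
size(R′)                   # (1, 16, 1)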
