diff --git a/src/mapreduce.jl b/src/mapreduce.jl index 1d84d78b4..e87885111 100644 --- a/src/mapreduce.jl +++ b/src/mapreduce.jl @@ -90,13 +90,14 @@ Base.@propagate_inbounds _map_getindex(args::Tuple{}, I) = () # Reduce an array across the grid. All elements to be processed can be addressed by the # product of the two iterators `Rreduce` and `Rother`, where the latter iterator will have # singleton entries for the dimensions that should be reduced (and vice versa). -function partial_mapreduce_device(f, op, neutral, maxthreads, Rreduce, Rother, shuffle, R, As...) +function partial_mapreduce_device(f, op, neutral, maxthreads, ::Val{Rreduce}, + ::Val{Rother}, ::Val{Rlen}, shuffle, R, As...) where {Rreduce, Rother, Rlen} # decompose the 1D hardware indices into separate ones for reduction (across items # and possibly groups if it doesn't fit) and other elements (remaining groups) localIdx_reduce = thread_position_in_threadgroup_1d() localDim_reduce = threads_per_threadgroup_1d() - groupIdx_reduce, groupIdx_other = fldmod1(threadgroup_position_in_grid_1d(), length(Rother)) - groupDim_reduce = threadgroups_per_grid_1d() ÷ length(Rother) + groupIdx_reduce, groupIdx_other = fldmod1(threadgroup_position_in_grid_1d(), Rlen) + groupDim_reduce = threadgroups_per_grid_1d() ÷ Rlen # group-based indexing into the values outside of the reduction dimension # (that means we can safely synchronize items within this group) @@ -171,8 +172,8 @@ function GPUArrays.mapreducedim!(f::F, op::OP, R::WrappedMtlArray{T}, # that's why each threads also loops across their inputs, processing multiple values # so that we can span the entire reduction dimension using a single item group. # XXX: can we query the 1024? - kernel = @metal launch=false partial_mapreduce_device(f, op, init, Val(1024), Rreduce, - Rother, Val(shuffle), R′, A) + kernel = @metal launch=false partial_mapreduce_device(f, op, init, Val(1024), Val(Rreduce), Val(Rother), + Val(UInt64(length(Rother))), Val(shuffle), R′, A) pipeline = MtlComputePipelineState(kernel.fun.lib.device, kernel.fun) # how many threads do we want? @@ -206,7 +207,7 @@ function GPUArrays.mapreducedim!(f::F, op::OP, R::WrappedMtlArray{T}, if reduce_groups == 1 # we can cover the dimensions to reduce using a single group @metal threads=threads grid=groups partial_mapreduce_device( - f, op, init, Val(threads), Rreduce, Rother, Val(shuffle), R′, A) + f, op, init, Val(threads), Val(Rreduce), Val(Rother), Val(UInt64(length(Rother))), Val(shuffle), R′, A) else # we need multiple steps to cover all values to reduce partial = similar(R, (size(R)..., reduce_groups)) @@ -220,7 +221,7 @@ function GPUArrays.mapreducedim!(f::F, op::OP, R::WrappedMtlArray{T}, end end @metal threads=threads grid=groups partial_mapreduce_device( - f, op, init, Val(threads), Rreduce, Rother, Val(shuffle), partial, A) + f, op, init, Val(threads), Val(Rreduce), Val(Rother), Val(UInt64(length(Rother))), Val(shuffle), partial, A) GPUArrays.mapreducedim!(identity, op, R′, partial; init=init) end