From 4a1b14a3394f1eca3f71d57d853f19172637d673 Mon Sep 17 00:00:00 2001
From: Charles Kawczynski
Date: Thu, 10 Oct 2024 09:49:38 -0400
Subject: [PATCH] Apply a few fixes

---
 ext/MultiBroadcastFusionCUDAExt.jl | 37 +++++++++++++++++-------------
 src/execution/fused_kernels.jl     | 32 ++++++++++++++++++++------
 test/runtests.jl                   |  6 +++++
 3 files changed, 52 insertions(+), 23 deletions(-)

diff --git a/ext/MultiBroadcastFusionCUDAExt.jl b/ext/MultiBroadcastFusionCUDAExt.jl
index 68459de..225746f 100644
--- a/ext/MultiBroadcastFusionCUDAExt.jl
+++ b/ext/MultiBroadcastFusionCUDAExt.jl
@@ -10,28 +10,33 @@ function fused_copyto!(fmb::MBF.FusedMultiBroadcast, ::MBF.MBF_CUDA)
     (; pairs) = fmb
     dest = first(pairs).first
     destinations = map(p -> p.first, pairs)
-    nitems = length(parent(dest))
-    max_threads = 256 # can be higher if conditions permit
-    nthreads = min(max_threads, nitems)
-    nblocks = cld(nitems, nthreads)
-    a1 = axes(dest)
     all(a -> axes(a) == axes(dest), destinations) ||
         error("Cannot fuse broadcast expressions with unequal broadcast axes")
+    nitems = length(parent(dest))
     CI = CartesianIndices(axes(dest))
-    CUDA.@cuda threads = (nthreads) blocks = (nblocks) fused_copyto_kernel!(
-        fmb,
-        CI,
-    )
-    return nothing
+    kernel =
+        CUDA.@cuda always_inline = true launch = false fused_copyto_kernel!(
+            fmb,
+            CI,
+        )
+    config = CUDA.launch_configuration(kernel.fun)
+    threads = min(nitems, config.threads)
+    blocks = cld(nitems, threads)
+    kernel(fmb, CI; threads, blocks)
+    return destinations
 end
 import Base.Broadcast
 function fused_copyto_kernel!(fmb::MBF.FusedMultiBroadcast, CI)
-    (; pairs) = fmb
-    dest = first(pairs).first
-    nitems = length(dest)
-    idx = CUDA.threadIdx().x + (CUDA.blockIdx().x - 1) * CUDA.blockDim().x
-    if idx ≤ nitems
-        MBF.rcopyto_at!(pairs, CI[idx])
+    @inbounds begin
+        (; pairs) = fmb
+        dest = first(pairs).first
+        nitems = length(dest)
+        idx =
+            CUDA.threadIdx().x +
+            (CUDA.blockIdx().x - Int32(1)) * CUDA.blockDim().x
+        if 1 ≤ idx ≤ nitems
+            MBF.rcopyto_at!(pairs, CI[idx])
+        end
     end
     return nothing
 end
diff --git a/src/execution/fused_kernels.jl b/src/execution/fused_kernels.jl
index c0d2fcc..67608c6 100644
--- a/src/execution/fused_kernels.jl
+++ b/src/execution/fused_kernels.jl
@@ -7,23 +7,40 @@ struct MBF_CUDA end
 device(x::AbstractArray) = MBF_CPU()
 
 function Base.copyto!(fmb::FusedMultiBroadcast)
-    pairs = fmb.pairs # (Pair(dest1, bc1),Pair(dest2, bc2),...)
+    # Since we intercept Base.copyto!, we have not yet
+    # called Base.Broadcast.instantiate (as this is done
+    # in materialize, which has been stripped away), so,
+    # let's call it here.
+    fmb′ = FusedMultiBroadcast(
+        map(fmb.pairs) do p
+            Pair(p.first, Base.Broadcast.instantiate(p.second))
+        end,
+    )
+    (; pairs) = fmb′ # (Pair(dest1, bc1),Pair(dest2, bc2),...)
     dest = first(pairs).first
-    fused_copyto!(fmb, device(dest))
+    fused_copyto!(fmb′, device(dest))
 end
 
-Base.@propagate_inbounds function rcopyto_at!(pair::Pair, i...)
+Base.@propagate_inbounds function rcopyto_at!(
+    pair::Pair,
+    i::Vararg{T},
+) where {T}
     dest, src = pair.first, pair.second
     @inbounds dest[i...] = src[i...]
     return nothing
 end
-Base.@propagate_inbounds function rcopyto_at!(pairs::Tuple, i...)
+Base.@propagate_inbounds function rcopyto_at!(
+    pairs::Tuple,
+    i::Vararg{T},
+) where {T}
     rcopyto_at!(first(pairs), i...)
     rcopyto_at!(Base.tail(pairs), i...)
 end
-Base.@propagate_inbounds rcopyto_at!(pairs::Tuple{<:Any}, i...) =
-    rcopyto_at!(first(pairs), i...)
-@inline rcopyto_at!(pairs::Tuple{}, i...) = nothing
+Base.@propagate_inbounds rcopyto_at!(
+    pairs::Tuple{<:Any},
+    i::Vararg{T},
+) where {T} = rcopyto_at!(first(pairs), i...)
+@inline rcopyto_at!(pairs::Tuple{}, i::Vararg{T}) where {T} = nothing
 
 # This is better than the baseline.
 function fused_copyto!(fmb::FusedMultiBroadcast, ::MBF_CPU)
@@ -39,6 +56,7 @@ function fused_copyto!(fmb::FusedMultiBroadcast, ::MBF_CPU)
             dest[i] = bc[i]
         end
     end
+    return destinations
 end
 
diff --git a/test/runtests.jl b/test/runtests.jl
index 8828b4b..500464d 100644
--- a/test/runtests.jl
+++ b/test/runtests.jl
@@ -1,4 +1,10 @@
 #=
+julia --project
+using TestEnv
+TestEnv.activate()
+using CUDA;
+ENV["PERFORM_BENCHMARK"]="true";
+
 using Revise; include(joinpath("test", "collection", "runtests.jl"))
 using Revise; include(joinpath("test", "execution", "runtests.jl"))
 using Revise; include(joinpath("test", "runtests.jl"))
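
The launch sequence adopted in the MultiBroadcastFusionCUDAExt.jl hunk (compile with launch = false, query CUDA.launch_configuration, then call the compiled kernel with explicit threads/blocks) is CUDA.jl's occupancy API. Below is a minimal, self-contained sketch of that pattern applied to a stand-alone kernel; it assumes CUDA.jl and a working GPU, and the names axpy_kernel! and occupancy_axpy! are illustrative only, not part of this package or patch.

using CUDA

function axpy_kernel!(y, x, a)
    # Same bounds-checked global-index pattern as the kernel in the patch.
    i = CUDA.threadIdx().x + (CUDA.blockIdx().x - Int32(1)) * CUDA.blockDim().x
    if 1 ≤ i ≤ length(y)
        @inbounds y[i] += a * x[i]
    end
    return nothing
end

function occupancy_axpy!(y, x, a)
    nitems = length(y)
    # Compile the kernel without launching it, ask CUDA.jl for an
    # occupancy-based thread count, then cover all items with enough blocks.
    kernel = CUDA.@cuda launch = false axpy_kernel!(y, x, a)
    config = CUDA.launch_configuration(kernel.fun)
    threads = min(nitems, config.threads)
    blocks = cld(nitems, threads)
    kernel(y, x, a; threads, blocks)
    return y
end

y = CUDA.zeros(Float32, 10_000);
x = CUDA.ones(Float32, 10_000);
occupancy_axpy!(y, x, 2.0f0)  # y now holds 2.0f0 in every element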