diff --git a/README.md b/README.md index 4aba76c..369f189 100644 --- a/README.md +++ b/README.md @@ -55,9 +55,9 @@ With this package, we can apply `@fused` to reduce the number of reads and prese ```julia import MultiBroadcastFusion as MBF -# `@fused` calls `Base.copyto!(::MBF.FusedMultiBroadcast)` -# So we must define `copyto!` for a subtype of `AbstractFusedMultiBroadcast` -function Base.copyto!(fmb::MBF.FusedMultiBroadcast) +MBF.@make_fused FusedMultiBroadcast fused +# Now, `@fused` will call `Base.copyto!(::FusedMultiBroadcast)`. Let's define it: +function Base.copyto!(fmb::FusedMultiBroadcast) pairs = fmb.pairs destinations = map(x->x.first, pairs) @inbounds for i in eachindex(destinations) @@ -75,7 +75,7 @@ y1 = rand(3,3) y2 = rand(3,3) # 4 reads, 2 writes -MBF.@fused begin +@fused begin @. y1 = x1 * x2 + x3 * x4 @. y2 = x1 * x3 + x2 * x4 end diff --git a/perf/flame.jl b/perf/flame.jl index db9b3ab..c8d3fea 100644 --- a/perf/flame.jl +++ b/perf/flame.jl @@ -17,7 +17,7 @@ Y = get_arrays(:y, arr_size, AType) function perf_kernel_fused!(X, Y) (; x1, x2, x3, x4, x5, x6, x7, x8, x9, x10) = X (; y1, y2, y3, y4, y5, y6, y7, y8, y9, y10) = Y - MBF.@fused begin + @fused begin @. y1 = x1 + x2 + x3 + x4 @. y2 = x2 + x3 + x4 + x5 @. 
y3 = x3 + x4 + x5 + x6 diff --git a/src/MultiBroadcastFusion.jl b/src/MultiBroadcastFusion.jl index df0b1df..646c1de 100644 --- a/src/MultiBroadcastFusion.jl +++ b/src/MultiBroadcastFusion.jl @@ -1,29 +1,5 @@ module MultiBroadcastFusion - -abstract type AbstractFusedMultiBroadcast end - -""" - FusedMultiBroadcast(pairs::Tuple) - -A mult-broadcast fusion object -""" -struct FusedMultiBroadcast{T} <: AbstractFusedMultiBroadcast - pairs::T -end - -# Base.@propagate_inbounds function rcopyto_at!(pair::Pair, i::CartesianIndex) -# dest,src = pair.first, pair.second -# @inbounds src_i = src[i] -# @inbounds dest[i] = src_i -# return nothing -# end -# Base.@propagate_inbounds function rcopyto_at!(pair::Pair, i::Int) -# dest,src = pair.first, pair.second -# @inbounds src_i = src[i] -# @inbounds dest[i] = src_i -# return nothing -# end Base.@propagate_inbounds function rcopyto_at!(pair::Pair, i...) dest, src = pair.first, pair.second @inbounds dest[i...] = src[i...] @@ -39,4 +15,50 @@ Base.@propagate_inbounds rcopyto_at!(pairs::Tuple{<:Any}, i...) = include("macro_utils.jl") +""" + @make_fused type_name fused_name + +This macro + - Imports MultiBroadcastFusion + - Defines a type, `type_name` + - Defines a macro, `@fused_name` + +This gives users the flexibility +to customize their broadcast fusion. + +# Example +```julia +import MultiBroadcastFusion as MBF +MBF.@make_fused MyFusedBroadcast my_fused + +Base.copyto!(fmb::MyFusedBroadcast) = println("You're ready to fuse!") + +x1 = rand(3,3) +y1 = rand(3,3) +y2 = rand(3,3) + +# 1 read, 2 writes +@my_fused begin + @. y1 = x1 + @. 
y2 = x1 +end +``` +""" +macro make_fused(type_name, fused_name) + t = esc(type_name) + f = esc(fused_name) + return quote + struct $t{T <: Tuple} + pairs::T + end + macro $f(expr) + _pairs = esc($(fused_pairs)(expr)) + t = $t + quote + Base.copyto!($t($_pairs)) + end + end + end +end + end # module MultiBroadcastFusion diff --git a/src/macro_utils.jl b/src/macro_utils.jl index 7afc3e4..733b755 100644 --- a/src/macro_utils.jl +++ b/src/macro_utils.jl @@ -5,27 +5,20 @@ function materialize_args(expr::Expr) return (expr.args[2], expr.args[3]) end -function fused(expr) end - -macro fused(expr) - _pairs = gensym() - quote - $_pairs = $(esc(fused_pairs(expr))) - Base.copyto!(FusedMultiBroadcast($_pairs)) - end -end - macro fused_pairs(expr) esc(fused_pairs(expr)) end function _fused_pairs(expr::Expr) - @assert expr.head == :block + # @assert expr.head == :block exprs_out = [] for _expr in expr.args # TODO: should we retain LineNumberNode? + if _expr isa Symbol # ???????? + return "" + end _expr isa LineNumberNode && continue - @assert _expr isa Expr + # @assert _expr isa Expr if _expr.head == :macrocall && _expr.args[1] == Symbol("@__dot__") se = code_lowered_single_expression(_expr) margs = materialize_args(se) @@ -41,16 +34,6 @@ end fused_pairs(expr::Expr) = Meta.parse(_fused_pairs(expr)) -macro fused_multibroadcast(expr) - esc(fused_multibroadcast("MultiBroadcastFusion.FusedMultiBroadcast", expr)) -end - -macro fused_multibroadcast(fmb, expr) - esc(fused_multibroadcast(fmb, expr)) -end -fused_multibroadcast(fmb, expr::Expr) = - Meta.parse("$(fmb)($(_fused_pairs(expr)))") - function build_expr(s::String, code_remain) n_subs = count("%", s) if n_subs > 0 diff --git a/test/bm_fused_reads_vs_hard_coded.jl b/test/bm_fused_reads_vs_hard_coded.jl index 1ea4d22..173190e 100644 --- a/test/bm_fused_reads_vs_hard_coded.jl +++ b/test/bm_fused_reads_vs_hard_coded.jl @@ -48,6 +48,28 @@ function knl_multi_copyto_hard_coded!(X, Y, ::Val{nitems}) where {nitems} return nothing end 
+function knl_multi_copyto_hard_coded!(X, Y, ::Val{nitems}) where {nitems} + (; x1, x2, x3, x4, x5, x6, x7, x8, x9, x10) = X + (; y1, y2, y3, y4, y5, y6, y7, y8, y9, y10) = Y + gidx = CUDA.threadIdx().x + (CUDA.blockIdx().x - 1) * CUDA.blockDim().x + if gidx < nitems + @inbounds begin + idx = gidx + y1[idx] = x1[idx] + x6[idx] + y2[idx] = x2[idx] + x7[idx] + y3[idx] = x3[idx] + x8[idx] + y4[idx] = x4[idx] + x9[idx] + y5[idx] = x5[idx] + x10[idx] + y6[idx] = y1[idx] + y7[idx] = y2[idx] + y8[idx] = y3[idx] + y9[idx] = y4[idx] + y10[idx] = y5[idx] + end + end + return nothing +end + # =========================================== has_cuda = CUDA.has_cuda() @@ -75,7 +97,7 @@ end function perf_kernel_fused!(X, Y) (; x1, x2, x3, x4, x5, x6, x7, x8, x9, x10) = X (; y1, y2, y3, y4, y5, y6, y7, y8, y9, y10) = Y - MBF.@fused begin + @fused begin @. y1 = x1 + x2 + x3 + x4 @. y2 = x2 + x3 + x4 + x5 @. y3 = x3 + x4 + x5 + x6 diff --git a/test/bm_fused_shared_reads.jl b/test/bm_fused_shared_reads.jl index 004382a..a742899 100644 --- a/test/bm_fused_shared_reads.jl +++ b/test/bm_fused_shared_reads.jl @@ -1,5 +1,5 @@ #= -using Revise; include(joinpath("test", "fused.jl")) +using Revise; include(joinpath("test", "bm_fused_shared_reads.jl")) =# include("utils.jl") @@ -19,7 +19,7 @@ end function perf_kernel_shared_reads_fused!(X, Y) (; x1, x2, x3, x4, x5, x6, x7, x8, x9, x10) = X (; y1, y2, y3, y4, y5, y6, y7, y8, y9, y10) = Y - MBF.@fused begin + @fused begin @. y1 = x1 + x2 + x3 + x4 @. y2 = x2 + x3 + x4 + x5 @. y3 = x3 + x4 + x5 + x6 diff --git a/test/bm_fused_shared_reads_writes.jl b/test/bm_fused_shared_reads_writes.jl index f1b7c80..a58f9a7 100644 --- a/test/bm_fused_shared_reads_writes.jl +++ b/test/bm_fused_shared_reads_writes.jl @@ -23,7 +23,7 @@ end function perf_kernel_shared_reads_writes_fused!(X, Y) (; x1, x2, x3, x4, x5, x6, x7, x8, x9, x10) = X (; y1, y2, y3, y4, y5, y6, y7, y8, y9, y10) = Y - MBF.@fused begin + @fused begin @. y1 = x1 + x6 @. y2 = x2 + x7 @. 
y3 = x3 + x8 diff --git a/test/expr_fused_pairs.jl b/test/expr_fused_pairs.jl index 5d3d63d..06c73d5 100644 --- a/test/expr_fused_pairs.jl +++ b/test/expr_fused_pairs.jl @@ -15,16 +15,3 @@ import MultiBroadcastFusion as MBF )) @test MBF.fused_pairs(expr_in) == expr_out end - -@testset "fused_multibroadcast" begin - expr_in = quote - @. y1 = x1 + x2 + x3 + x4 - @. y2 = x2 + x3 + x4 + x5 - end - - expr_out = :(MultiBroadcastFusion.FusedMultiBroadcast(( - Pair(y1, Base.broadcasted(+, x1, x2, x3, x4)), - Pair(y2, Base.broadcasted(+, x2, x3, x4, x5)), - ))) - @test MBF.fused_multibroadcast(MBF.FusedMultiBroadcast, expr_in) == expr_out -end diff --git a/test/utils.jl b/test/utils.jl index d3d8b5e..a31f4eb 100644 --- a/test/utils.jl +++ b/test/utils.jl @@ -22,10 +22,11 @@ device(x) = GPU() # in every file without getting warnings. To silence the warnings, # we've wrapped this in an if-statement. # WARNING: -# If you've updated `Base.copyto!(fmb::MBF.FusedMultiBroadcast)`, +# If you've updated `Base.copyto!(fmb::FusedMultiBroadcast)`, # then Revise will not update this method!!! -if !hasmethod(Base.copyto!, Tuple{<:MBF.FusedMultiBroadcast}) - function Base.copyto!(fmb::MBF.FusedMultiBroadcast) +MBF.@make_fused FusedMultiBroadcast fused +if !hasmethod(Base.copyto!, Tuple{<:FusedMultiBroadcast}) + function Base.copyto!(fmb::FusedMultiBroadcast) pairs = fmb.pairs dest = first(pairs).first @assert device(dest) isa CPU || device(dest) isa GPU @@ -45,7 +46,7 @@ if !hasmethod(Base.copyto!, Tuple{<:MBF.FusedMultiBroadcast}) end function copyto_cpu!(pairs::T, ei::EI) where {T, EI} - @inbounds for i in ei + @inbounds @simd ivdep for i in ei MBF.rcopyto_at!(pairs, i) end return nothing