diff --git a/Project.toml b/Project.toml index 89980e7..ad03a6d 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "MultiBroadcastFusion" uuid = "c3c07f87-98de-43f2-a76f-835b330b2cbb" authors = ["CliMA Contributors "] -version = "0.1.0" +version = "0.1.1" [compat] julia = "^1.10" diff --git a/src/MultiBroadcastFusion.jl b/src/MultiBroadcastFusion.jl index 646c1de..80620a9 100644 --- a/src/MultiBroadcastFusion.jl +++ b/src/MultiBroadcastFusion.jl @@ -13,7 +13,10 @@ Base.@propagate_inbounds rcopyto_at!(pairs::Tuple{<:Any}, i...) = rcopyto_at!(first(pairs), i...) @inline rcopyto_at!(pairs::Tuple{}, i...) = nothing -include("macro_utils.jl") +include("utils.jl") +include("code_lowered_single_expression.jl") +include("fused_pairs.jl") +include("fused_pairs_flexible.jl") """ @make_fused type_name fused_named @@ -61,4 +64,50 @@ macro make_fused(type_name, fused_name) end end +""" + @make_fused_flexible type_name fused_named + +This macro + - Imports MultiBroadcastFusion + - Defines a type, `type_name` + - Defines a macro, `@fused_name` + +This allows users to flexibility +to customize their broadcast fusion. + +# Example +```julia +import MultiBroadcastFusion as MBF +MBF.@make_fused_flexible MyFusedBroadcast my_fused + +Base.copyto!(fmb::MyFusedBroadcast) = println("You're ready to fuse!") + +x1 = rand(3,3) +y1 = rand(3,3) +y2 = rand(3,3) + +# 4 reads, 2 writes +@my_fused begin + @. y1 = x1 + @. y2 = x1 +end +``` +""" +macro make_fused_flexible(type_name, fused_name) + t = esc(type_name) + f = esc(fused_name) + return quote + struct $t{T <: Tuple} + pairs::T + end + macro $f(expr) + _pairs = esc($(fused_pairs_flexible)(expr, gensym())) + t = $t + quote + Base.copyto!($t($_pairs)) + end + end + end +end + end # module MultiBroadcastFusion diff --git a/src/code_lowered_single_expression.jl b/src/code_lowered_single_expression.jl new file mode 100644 index 0000000..d4a9df6 --- /dev/null +++ b/src/code_lowered_single_expression.jl @@ -0,0 +1,15 @@ +# General case: do nothing (identity) +substitute(x, code) = x +substitute(x::Core.SSAValue, code) = substitute(code[x.id], code) +substitute(x::Core.ReturnNode, code) = substitute(code[x.val.id], code) +substitute(s::Symbol, code) = s +# Expression: recursively substitute for Expr +substitute(e::Expr, code) = + Expr(substitute(e.head, code), substitute.(e.args, Ref(code))...) + +code_info(expr) = Base.Meta.lower(Main, expr).args[1] +function code_lowered_single_expression(expr) + code = code_info(expr).code # vector + s = string(substitute(code[end], code)) + return Base.Meta.parse(s) +end diff --git a/src/fused_pairs.jl b/src/fused_pairs.jl new file mode 100644 index 0000000..b8886f3 --- /dev/null +++ b/src/fused_pairs.jl @@ -0,0 +1,59 @@ +##### +##### Simple version +##### + +# General case: do nothing (identity) +transform(x) = x +transform(x::Core.SSAValue) = transform(code[x.id]) +transform(x::Core.ReturnNode) = transform(code[x.val.id]) +transform(s::Symbol) = s +# Expression: recursively transform for Expr +function transform(e::Expr) + if e.head == :macrocall && e.args[1] == Symbol("@__dot__") + se = code_lowered_single_expression(e) + margs = materialize_args(se) + subexpr = :(Pair($(margs[1]), $(margs[2]))) + subexpr + else + Expr(transform(e.head), transform.(e.args)...) + end +end + +function fused_pairs(expr::Expr) + check_restrictions(expr) + e = transform(expr) + @assert e.head == :block + ex = Expr(:call, :tuple, e.args...) + # Filter out LineNumberNode, as this will not be valid due to prepending `tup = ()` + linefilter!(ex) + ex +end + +function check_restrictions(expr::Expr) + for _expr in expr.args + _expr isa LineNumberNode && continue + s_error = if _expr isa QuoteNode + "Dangling symbols are not allowed inside fused blocks" + elseif _expr.head == :for + "Loops are not allowed inside fused blocks" + elseif _expr.head == :if + "If-statements are not allowed inside fused blocks" + elseif _expr.head == :call + "Function calls are not allowed inside fused blocks" + elseif _expr.head == :(=) + "Non-broadcast assignments are not allowed inside fused blocks" + elseif _expr.head == :let + "Let-blocks are not allowed inside fused blocks" + elseif _expr.head == :quote + "Quotes are not allowed inside fused blocks" + else + "" + end + isempty(s_error) || error(s_error) + if _expr.head == :macrocall && _expr.args[1] == Symbol("@__dot__") + else + @show dump(_expr) + error("Uncaught edge case") + end + end +end diff --git a/src/fused_pairs_flexible.jl b/src/fused_pairs_flexible.jl new file mode 100644 index 0000000..7018cba --- /dev/null +++ b/src/fused_pairs_flexible.jl @@ -0,0 +1,62 @@ +##### +##### Complex/flexible version +##### + +# General case: do nothing (identity) +transform_flex(x, sym) = x +transform_flex(x::Core.SSAValue, sym) = transform_flex(code[x.id], sym) +transform_flex(x::Core.ReturnNode, sym) = transform_flex(code[x.val.id], sym) +transform_flex(s::Symbol, sym) = s +# Expression: recursively transform_flex for Expr +function transform_flex(e::Expr, sym) + if e.head == :macrocall && e.args[1] == Symbol("@__dot__") + se = code_lowered_single_expression(e) + margs = materialize_args(se) + subexpr = :($sym = ($sym..., Pair($(margs[1]), $(margs[2])))) + subexpr + else + Expr(transform_flex(e.head, sym), transform_flex.(e.args, sym)...) + end +end + +function fused_pairs_flexible(expr::Expr, sym::Symbol) + check_restrictions_flexible(expr) + e = transform_flex(expr, sym) + @assert e.head == :block + ex = Expr(:block, :($sym = ()), e.args..., sym) + # Filter out LineNumberNode, as this will not be valid due to prepending `tup = ()` + linefilter!(ex) + ex +end + +function check_restrictions_flexible(expr::Expr) + for arg in expr.args + arg isa LineNumberNode && continue + s_error = if arg isa QuoteNode + "Dangling symbols are not allowed inside fused blocks" + elseif arg.head == :call + "Function calls are not allowed inside fused blocks" + elseif arg.head == :(=) + "Non-broadcast assignments are not allowed inside fused blocks" + elseif arg.head == :let + "Let-blocks are not allowed inside fused blocks" + elseif arg.head == :quote + "Quotes are not allowed inside fused blocks" + else + "" + end + isempty(s_error) || error(s_error) + + if arg.head == :macrocall && arg.args[1] == Symbol("@__dot__") + elseif arg.head == :for + check_restrictions(arg.args[2]) + elseif arg.head == :if + check_restrictions(arg.args[2]) + elseif arg.head == :macrocall && arg.args[1] == Symbol("@inbounds") + else + @show dump(arg) + error("Uncaught edge case") + end + end + return nothing +end diff --git a/src/macro_utils.jl b/src/macro_utils.jl deleted file mode 100644 index 54eded2..0000000 --- a/src/macro_utils.jl +++ /dev/null @@ -1,72 +0,0 @@ - -function materialize_args(expr::Expr) - @assert expr.head == :call - @assert expr.args[1] == :(Base.materialize!) - return (expr.args[2], expr.args[3]) -end - -macro fused_pairs(expr) - esc(fused_pairs(expr)) -end - -function _fused_pairs(expr::Expr) - # @assert expr.head == :block - exprs_out = [] - for _expr in expr.args - # TODO: should we retain LineNumberNode? - # if _expr isa Symbol # ???????? - # error("???") - # return "" - # end - if _expr isa QuoteNode - error("Dangling symbols are not allowed inside fused blocks") - end - _expr isa LineNumberNode && continue - if _expr.head == :for - error("Loops are not allowed inside fused blocks") - elseif _expr.head == :if - error("If-statements are not allowed inside fused blocks") - elseif _expr.head == :call - error("Function calls are not allowed inside fused blocks") - elseif _expr.head == :(=) - error( - "Non-broadcast assignments are not allowed inside fused blocks", - ) - elseif _expr.head == :let - error("Let-blocks are not allowed inside fused blocks") - elseif _expr.head == :quote - error("Quotes are not allowed inside fused blocks") - end - if _expr.head == :macrocall && _expr.args[1] == Symbol("@__dot__") - se = code_lowered_single_expression(_expr) - margs = materialize_args(se) - push!(exprs_out, :(Pair($(margs[1]), $(margs[2])))) - else - @show dump(_expr) - error("Uncaught edge case") - end - end - if length(exprs_out) == 1 - return "($(exprs_out[1]),)" - else - return "(" * join(exprs_out, ",") * ")" - end -end - -fused_pairs(expr::Expr) = Meta.parse(_fused_pairs(expr)) - -# General case: do nothing (identity) -substitute(x, code) = x -substitute(x::Core.SSAValue, code) = substitute(code[x.id], code) -substitute(x::Core.ReturnNode, code) = substitute(code[x.val.id], code) -substitute(s::Symbol, code) = s -# Expression: recursively substitute for Expr -substitute(e::Expr, code) = - Expr(substitute(e.head, code), substitute.(e.args, Ref(code))...) - -code_info(expr) = Base.Meta.lower(Main, expr).args[1] -function code_lowered_single_expression(expr) - code = code_info(expr).code # vector - s = string(substitute(code[end], code)) - return Base.Meta.parse(s) -end diff --git a/src/utils.jl b/src/utils.jl new file mode 100644 index 0000000..7221382 --- /dev/null +++ b/src/utils.jl @@ -0,0 +1,36 @@ +##### +##### Helper +##### + +# Recursively remove LineNumberNode from an `Expr` +@noinline function linefilter!(expr::Expr) + total = length(expr.args) + i = 0 + while i < total + i += 1 + if expr.args[i] |> typeof == Expr + if expr.args[i].head == :line + deleteat!(expr.args, i) + total -= 1 + i -= 1 + else + expr.args[i] = linefilter!(expr.args[i]) + end + elseif expr.args[i] |> typeof == LineNumberNode + if expr.head == :macrocall + expr.args[i] = nothing + else + deleteat!(expr.args, i) + total -= 1 + i -= 1 + end + end + end + return expr +end + +function materialize_args(expr::Expr) + @assert expr.head == :call + @assert expr.args[1] == :(Base.materialize!) + return (expr.args[2], expr.args[3]) +end diff --git a/test/expr_errors_and_edge_cases.jl b/test/expr_errors_and_edge_cases.jl index a9cdb1f..dce47f8 100644 --- a/test/expr_errors_and_edge_cases.jl +++ b/test/expr_errors_and_edge_cases.jl @@ -123,7 +123,7 @@ end @. y2 = x2 + x3 + x4 + x5 end - expr_out = :(( + expr_out = :(tuple( Pair(y1, Base.broadcasted(+, x1, x2, x3, x4)), Pair(y2, Base.broadcasted(+, x2, x3, x4, x5)), )) @@ -132,5 +132,5 @@ end @testset "Empty" begin expr_in = quote end - @test MBF.fused_pairs(expr_in) == :(()) + @test MBF.fused_pairs(expr_in) == :(tuple()) end diff --git a/test/expr_fused_pairs.jl b/test/expr_fused_pairs.jl index 06c73d5..898a946 100644 --- a/test/expr_fused_pairs.jl +++ b/test/expr_fused_pairs.jl @@ -3,15 +3,99 @@ using Revise; include(joinpath("test", "expr_fused_pairs.jl")) =# using Test import MultiBroadcastFusion as MBF + @testset "fused_pairs" begin expr_in = quote @. y1 = x1 + x2 + x3 + x4 @. y2 = x2 + x3 + x4 + x5 end - expr_out = :(( + expr_out = :(tuple( Pair(y1, Base.broadcasted(+, x1, x2, x3, x4)), Pair(y2, Base.broadcasted(+, x2, x3, x4, x5)), )) @test MBF.fused_pairs(expr_in) == expr_out end + +@testset "fused_pairs_flexible - simple sequential" begin + expr_in = quote + @. y1 = x1 + x2 + x3 + x4 + @. y2 = x2 + x3 + x4 + x5 + end + + expr_out = quote + tup = () + tup = (tup..., Pair(y1, Base.broadcasted(+, x1, x2, x3, x4))) + tup = (tup..., Pair(y2, Base.broadcasted(+, x2, x3, x4, x5))) + tup + end + + @test MBF.linefilter!(MBF.fused_pairs_flexible(expr_in, :tup)) == + MBF.linefilter!(expr_out) + @test MBF.fused_pairs_flexible(expr_in, :tup) == expr_out +end + + +@testset "fused_pairs_flexible - loop" begin + expr_in = quote + for i in 1:10 + @. y1 = x1 + x2 + x3 + x4 + @. y2 = x2 + x3 + x4 + x5 + end + end + + expr_out = quote + tup = () + for i in 1:10 + tup = (tup..., Pair(y1, Base.broadcasted(+, x1, x2, x3, x4))) + tup = (tup..., Pair(y2, Base.broadcasted(+, x2, x3, x4, x5))) + end + tup + end + + @test MBF.linefilter!(MBF.fused_pairs_flexible(expr_in, :tup)) == + MBF.linefilter!(expr_out) + @test MBF.fused_pairs_flexible(expr_in, :tup) == expr_out +end + +@testset "fused_pairs_flexible - loop with @inbounds" begin + expr_in = quote + @inbounds for i in 1:10 + @. y1 = x1 + x2 + x3 + x4 + @. y2 = x2 + x3 + x4 + x5 + end + end + + expr_out = quote + tup = () + @inbounds for i in 1:10 + tup = (tup..., Pair(y1, Base.broadcasted(+, x1, x2, x3, x4))) + tup = (tup..., Pair(y2, Base.broadcasted(+, x2, x3, x4, x5))) + end + tup + end + @test MBF.linefilter!(MBF.fused_pairs_flexible(expr_in, :tup)) == + MBF.linefilter!(expr_out) + @test MBF.fused_pairs_flexible(expr_in, :tup) == expr_out +end + +@testset "fused_pairs_flexible - if" begin + expr_in = quote + if a && B || something(x, y, z) + @. y1 = x1 + x2 + x3 + x4 + @. y2 = x2 + x3 + x4 + x5 + end + end + + expr_out = quote + tup = () + if a && B || something(x, y, z) + tup = (tup..., Pair(y1, Base.broadcasted(+, x1, x2, x3, x4))) + tup = (tup..., Pair(y2, Base.broadcasted(+, x2, x3, x4, x5))) + end + tup + end + @test MBF.linefilter!(MBF.fused_pairs_flexible(expr_in, :tup)) == + MBF.linefilter!(expr_out) + @test MBF.fused_pairs_flexible(expr_in, :tup) == expr_out +end