diff --git a/src/collection/fused_assemble.jl b/src/collection/fused_assemble.jl index a053e85..1d3c319 100644 --- a/src/collection/fused_assemble.jl +++ b/src/collection/fused_assemble.jl @@ -12,6 +12,11 @@ function transform_assemble(e::Expr, sym) margs = materialize_args(se) subexpr = :($sym = ($sym..., Pair($(margs[1]), $(margs[2])))) subexpr + elseif e.head == Symbol(".=") + se = code_lowered_single_expression(e) + margs = materialize_args(se) + subexpr = :($sym = ($sym..., Pair($(margs[1]), $(margs[2])))) + subexpr else Expr( transform_assemble(e.head, sym), @@ -90,7 +95,7 @@ function check_restrictions_assemble(expr::Expr) arg isa LineNumberNode && continue s_error = if arg isa QuoteNode "Dangling symbols are not allowed inside fused blocks" - elseif arg.head == :call + elseif arg.head == :call && !(isa_dot_op(arg[1])) "Function calls are not allowed inside fused blocks" elseif arg.head == :(=) "Non-broadcast assignments are not allowed inside fused blocks" @@ -109,6 +114,13 @@ function check_restrictions_assemble(expr::Expr) elseif arg.head == :if check_restrictions(arg.args[2]) elseif arg.head == :macrocall && arg.args[1] == Symbol("@inbounds") + elseif arg.head == :call && isa_dot_op(arg.args[1]) + # Allows for :(a .+ foo(b)) + # where foo(b) could be a getter to an array. + # This technically opens the door to incorrectness, + # as foo could change the pointer of `b` to something else + # however, this seems unlikely. + elseif isa_dot_op(arg.head) # dot function call else @show dump(arg) error("Uncaught edge case") diff --git a/src/collection/fused_direct.jl b/src/collection/fused_direct.jl index e1449bd..ab7a478 100644 --- a/src/collection/fused_direct.jl +++ b/src/collection/fused_direct.jl @@ -12,6 +12,11 @@ function transform(e::Expr) margs = materialize_args(se) subexpr = :(Pair($(margs[1]), $(margs[2]))) subexpr + elseif e.head == Symbol(".=") + se = code_lowered_single_expression(e) + margs = materialize_args(se) + subexpr = :(Pair($(margs[1]), $(margs[2]))) + subexpr else Expr(transform(e.head), transform.(e.args)...) end @@ -82,7 +87,7 @@ function check_restrictions(expr::Expr) "Loops are not allowed inside fused blocks" elseif _expr.head == :if "If-statements are not allowed inside fused blocks" - elseif _expr.head == :call + elseif _expr.head == :call && !(isa_dot_op(_expr.args[1])) "Function calls are not allowed inside fused blocks" elseif _expr.head == :(=) "Non-broadcast assignments are not allowed inside fused blocks" @@ -95,6 +100,13 @@ function check_restrictions(expr::Expr) end isempty(s_error) || error(s_error) if _expr.head == :macrocall && _expr.args[1] == Symbol("@__dot__") + elseif _expr.head == :call && isa_dot_op(_expr.args[1]) + # Allows for :(a .+ foo(b)) + # where foo(b) could be a getter to an array. + # This technically opens the door to incorrectness, + # as foo could change the pointer of `b` to something else + # however, this seems unlikely. + elseif isa_dot_op(_expr.head) # dot function call else @show dump(_expr) error("Uncaught edge case") diff --git a/src/collection/utils.jl b/src/collection/utils.jl index 7221382..a241352 100644 --- a/src/collection/utils.jl +++ b/src/collection/utils.jl @@ -31,6 +31,30 @@ end function materialize_args(expr::Expr) @assert expr.head == :call - @assert expr.args[1] == :(Base.materialize!) - return (expr.args[2], expr.args[3]) + if expr.args[1] == :(Base.materialize!) + return (expr.args[2], expr.args[3]) + elseif expr.args[1] == :(Base.materialize) + return (expr.args[2], expr.args[2]) + else + error("Uncaught edge case.") + end end + +const dot_ops = ( + Symbol(".+"), + Symbol(".-"), + Symbol(".*"), + Symbol("./"), + Symbol(".="), + Symbol(".=="), + Symbol(".≠"), + Symbol(".^"), + Symbol(".!="), + Symbol(".>"), + Symbol(".<"), + Symbol(".>="), + Symbol(".<="), + Symbol(".≤"), + Symbol(".≥"), +) +isa_dot_op(op) = any(x -> op == x, dot_ops) diff --git a/test/collection/expr_fused_assemble.jl b/test/collection/expr_fused_assemble.jl index fc7fd8c..e2e407f 100644 --- a/test/collection/expr_fused_assemble.jl +++ b/test/collection/expr_fused_assemble.jl @@ -22,6 +22,25 @@ import MultiBroadcastFusion as MBF @test MBF.fused_assemble(expr_in, :tup) == expr_out end +#! format: off +@testset "fused_assemble - simple sequential, explicit dots" begin + expr_in = quote + y1 .= x1 .+ x2 .+ x3 .+ x4 + y2 .= x2 .+ x3 .+ x4 .+ x5 + end + + expr_out = quote + tup = () + tup = (tup..., Pair(y1, Base.broadcasted(+, Base.broadcasted(+, Base.broadcasted(+, x1, x2), x3), x4))) + tup = (tup..., Pair(y2, Base.broadcasted(+, Base.broadcasted(+, Base.broadcasted(+, x2, x3), x4), x5))) + tup + end + + @test MBF.linefilter!(MBF.fused_assemble(expr_in, :tup)) == + MBF.linefilter!(expr_out) + @test MBF.fused_assemble(expr_in, :tup) == expr_out +end +#! format: on @testset "fused_assemble - loop" begin expr_in = quote diff --git a/test/collection/expr_fused_direct.jl b/test/collection/expr_fused_direct.jl index 63d2d4c..efb0261 100644 --- a/test/collection/expr_fused_direct.jl +++ b/test/collection/expr_fused_direct.jl @@ -16,3 +16,18 @@ import MultiBroadcastFusion as MBF )) @test MBF.fused_direct(expr_in) == expr_out end + +#! format: off +@testset "fused_direct - explicit dots" begin + expr_in = quote + y1 .= x1 .+ x2 .+ x3 .+ x4 + y2 .= x2 .+ x3 .+ x4 .+ x5 + end + + expr_out = :(tuple( + Pair(y1, Base.broadcasted(+, Base.broadcasted(+, Base.broadcasted(+, x1, x2), x3), x4)), + Pair(y2, Base.broadcasted(+, Base.broadcasted(+, Base.broadcasted(+, x2, x3), x4), x5)), + )) + @test MBF.fused_direct(expr_in) == expr_out +end +#! format: on