Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support .= syntax #26

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 13 additions & 1 deletion src/collection/fused_assemble.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,11 @@
margs = materialize_args(se)
subexpr = :($sym = ($sym..., Pair($(margs[1]), $(margs[2]))))
subexpr
elseif e.head == Symbol(".=")
se = code_lowered_single_expression(e)
margs = materialize_args(se)
subexpr = :($sym = ($sym..., Pair($(margs[1]), $(margs[2]))))
subexpr

Check warning on line 19 in src/collection/fused_assemble.jl

View check run for this annotation

Codecov / codecov/patch

src/collection/fused_assemble.jl#L19

Added line #L19 was not covered by tests
else
Expr(
transform_assemble(e.head, sym),
Expand Down Expand Up @@ -90,7 +95,7 @@
arg isa LineNumberNode && continue
s_error = if arg isa QuoteNode
"Dangling symbols are not allowed inside fused blocks"
elseif arg.head == :call
elseif arg.head == :call && !(isa_dot_op(arg[1]))
"Function calls are not allowed inside fused blocks"
elseif arg.head == :(=)
"Non-broadcast assignments are not allowed inside fused blocks"
Expand All @@ -109,6 +114,13 @@
elseif arg.head == :if
check_restrictions(arg.args[2])
elseif arg.head == :macrocall && arg.args[1] == Symbol("@inbounds")
elseif arg.head == :call && isa_dot_op(arg.args[1])
# Allows for :(a .+ foo(b))
# where foo(b) could be a getter to an array.
# This technically opens the door to incorrectness,
# as foo could change the pointer of `b` to something else
# however, this seems unlikely.
elseif isa_dot_op(arg.head) # dot function call
else
@show dump(arg)
error("Uncaught edge case")
Expand Down
14 changes: 13 additions & 1 deletion src/collection/fused_direct.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,11 @@
margs = materialize_args(se)
subexpr = :(Pair($(margs[1]), $(margs[2])))
subexpr
elseif e.head == Symbol(".=")
se = code_lowered_single_expression(e)
margs = materialize_args(se)
subexpr = :(Pair($(margs[1]), $(margs[2])))
subexpr

Check warning on line 19 in src/collection/fused_direct.jl

View check run for this annotation

Codecov / codecov/patch

src/collection/fused_direct.jl#L19

Added line #L19 was not covered by tests
else
Expr(transform(e.head), transform.(e.args)...)
end
Expand Down Expand Up @@ -82,7 +87,7 @@
"Loops are not allowed inside fused blocks"
elseif _expr.head == :if
"If-statements are not allowed inside fused blocks"
elseif _expr.head == :call
elseif _expr.head == :call && !(isa_dot_op(_expr.args[1]))
"Function calls are not allowed inside fused blocks"
elseif _expr.head == :(=)
"Non-broadcast assignments are not allowed inside fused blocks"
Expand All @@ -95,6 +100,13 @@
end
isempty(s_error) || error(s_error)
if _expr.head == :macrocall && _expr.args[1] == Symbol("@__dot__")
elseif _expr.head == :call && isa_dot_op(_expr.args[1])
# Allows for :(a .+ foo(b))
# where foo(b) could be a getter to an array.
# This technically opens the door to incorrectness,
# as foo could change the pointer of `b` to something else
# however, this seems unlikely.
elseif isa_dot_op(_expr.head) # dot function call
else
@show dump(_expr)
error("Uncaught edge case")
Expand Down
28 changes: 26 additions & 2 deletions src/collection/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,30 @@

function materialize_args(expr::Expr)
@assert expr.head == :call
@assert expr.args[1] == :(Base.materialize!)
return (expr.args[2], expr.args[3])
if expr.args[1] == :(Base.materialize!)
return (expr.args[2], expr.args[3])
elseif expr.args[1] == :(Base.materialize)
return (expr.args[2], expr.args[2])

Check warning on line 37 in src/collection/utils.jl

View check run for this annotation

Codecov / codecov/patch

src/collection/utils.jl#L36-L37

Added lines #L36 - L37 were not covered by tests
else
error("Uncaught edge case.")

Check warning on line 39 in src/collection/utils.jl

View check run for this annotation

Codecov / codecov/patch

src/collection/utils.jl#L39

Added line #L39 was not covered by tests
end
end

const dot_ops = (
Symbol(".+"),
Symbol(".-"),
Symbol(".*"),
Symbol("./"),
Symbol(".="),
Symbol(".=="),
Symbol(".≠"),
Symbol(".^"),
Symbol(".!="),
Symbol(".>"),
Symbol(".<"),
Symbol(".>="),
Symbol(".<="),
Symbol(".≤"),
Symbol(".≥"),
)
isa_dot_op(op) = any(x -> op == x, dot_ops)
19 changes: 19 additions & 0 deletions test/collection/expr_fused_assemble.jl
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,25 @@ import MultiBroadcastFusion as MBF
@test MBF.fused_assemble(expr_in, :tup) == expr_out
end

#! format: off
@testset "fused_assemble - simple sequential, explicit dots" begin
expr_in = quote
y1 .= x1 .+ x2 .+ x3 .+ x4
y2 .= x2 .+ x3 .+ x4 .+ x5
end

expr_out = quote
tup = ()
tup = (tup..., Pair(y1, Base.broadcasted(+, Base.broadcasted(+, Base.broadcasted(+, x1, x2), x3), x4)))
tup = (tup..., Pair(y2, Base.broadcasted(+, Base.broadcasted(+, Base.broadcasted(+, x2, x3), x4), x5)))
tup
end

@test MBF.linefilter!(MBF.fused_assemble(expr_in, :tup)) ==
MBF.linefilter!(expr_out)
@test MBF.fused_assemble(expr_in, :tup) == expr_out
end
#! format: on

@testset "fused_assemble - loop" begin
expr_in = quote
Expand Down
15 changes: 15 additions & 0 deletions test/collection/expr_fused_direct.jl
Original file line number Diff line number Diff line change
Expand Up @@ -16,3 +16,18 @@ import MultiBroadcastFusion as MBF
))
@test MBF.fused_direct(expr_in) == expr_out
end

#! format: off
@testset "fused_direct - explicit dots" begin
expr_in = quote
y1 .= x1 .+ x2 .+ x3 .+ x4
y2 .= x2 .+ x3 .+ x4 .+ x5
end

expr_out = :(tuple(
Pair(y1, Base.broadcasted(+, Base.broadcasted(+, Base.broadcasted(+, x1, x2), x3), x4)),
Pair(y2, Base.broadcasted(+, Base.broadcasted(+, Base.broadcasted(+, x2, x3), x4), x5)),
))
@test MBF.fused_direct(expr_in) == expr_out
end
#! format: on
Loading