-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
4852fe5
commit 77dbc0b
Showing
11 changed files
with
256 additions
and
235 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,83 +1,82 @@ | ||
|
||
function materialize_args(expr::Expr) | ||
@assert expr.head == :call | ||
@assert expr.args[1] == :(Base.materialize!) | ||
return (expr.args[2], expr.args[3]) | ||
@assert expr.head == :call | ||
@assert expr.args[1] == :(Base.materialize!) | ||
return (expr.args[2], expr.args[3]) | ||
end | ||
|
||
function fused(expr) | ||
end | ||
function fused(expr) end | ||
|
||
macro fused(expr) | ||
_pairs = gensym() | ||
quote | ||
$_pairs = $(esc(fused_pairs(expr))) | ||
Base.copyto!(FusedMultiBroadcast($_pairs)) | ||
end | ||
_pairs = gensym() | ||
quote | ||
$_pairs = $(esc(fused_pairs(expr))) | ||
Base.copyto!(FusedMultiBroadcast($_pairs)) | ||
end | ||
end | ||
|
||
macro fused_pairs(expr) | ||
esc(fused_pairs(expr)) | ||
esc(fused_pairs(expr)) | ||
end | ||
|
||
function _fused_pairs(expr::Expr) | ||
@assert expr.head == :block | ||
exprs_out = [] | ||
for _expr in expr.args | ||
# TODO: should we retain LineNumberNode? | ||
_expr isa LineNumberNode && continue | ||
@assert _expr isa Expr | ||
if _expr.head == :macrocall && _expr.args[1] == Symbol("@__dot__") | ||
se = code_lowered_single_expression(_expr) | ||
margs = materialize_args(se) | ||
push!(exprs_out, :(Pair($(margs[1]), $(margs[2])))) | ||
end | ||
end | ||
if length(exprs_out) == 1 | ||
return "($(exprs_out[1]),)" | ||
else | ||
return "("*join(exprs_out, ",")*")" | ||
end | ||
@assert expr.head == :block | ||
exprs_out = [] | ||
for _expr in expr.args | ||
# TODO: should we retain LineNumberNode? | ||
_expr isa LineNumberNode && continue | ||
@assert _expr isa Expr | ||
if _expr.head == :macrocall && _expr.args[1] == Symbol("@__dot__") | ||
se = code_lowered_single_expression(_expr) | ||
margs = materialize_args(se) | ||
push!(exprs_out, :(Pair($(margs[1]), $(margs[2])))) | ||
end | ||
end | ||
if length(exprs_out) == 1 | ||
return "($(exprs_out[1]),)" | ||
else | ||
return "(" * join(exprs_out, ",") * ")" | ||
end | ||
end | ||
|
||
fused_pairs(expr::Expr) = Meta.parse(_fused_pairs(expr)) | ||
|
||
macro fused_multibroadcast(expr) | ||
esc(fused_multibroadcast("MultiBroadcastFusion.FusedMultiBroadcast", expr)) | ||
esc(fused_multibroadcast("MultiBroadcastFusion.FusedMultiBroadcast", expr)) | ||
end | ||
|
||
macro fused_multibroadcast(fmb, expr) | ||
esc(fused_multibroadcast(fmb, expr)) | ||
esc(fused_multibroadcast(fmb, expr)) | ||
end | ||
fused_multibroadcast(fmb, expr::Expr) = | ||
Meta.parse("$(fmb)($(_fused_pairs(expr)))") | ||
Meta.parse("$(fmb)($(_fused_pairs(expr)))") | ||
|
||
function build_expr(s::String, code_remain) | ||
n_subs = count("%", s) | ||
if n_subs > 0 | ||
while n_subs > 0 | ||
regex = r"%[0-9]" | ||
m = match(regex, s) | ||
smatch = m.match | ||
j = Meta.parse(smatch[2:end]) | ||
s = replace(s, smatch => string(code_remain[j])) | ||
n_subs = count("%", s) | ||
end | ||
else | ||
return s | ||
end | ||
return s | ||
n_subs = count("%", s) | ||
if n_subs > 0 | ||
while n_subs > 0 | ||
regex = r"%[0-9]" | ||
m = match(regex, s) | ||
smatch = m.match | ||
j = Meta.parse(smatch[2:end]) | ||
s = replace(s, smatch => string(code_remain[j])) | ||
n_subs = count("%", s) | ||
end | ||
else | ||
return s | ||
end | ||
return s | ||
end | ||
|
||
build_expr(code::Vector) = build_expr(string(code[end]), code) | ||
|
||
function code_lowered_single_expression(expr) | ||
code_lowered = Base.Meta.lower(Main, expr) | ||
code_info = code_lowered.args[1] | ||
code = code_info.code # vector | ||
s = build_expr(code) | ||
if startswith(s, "return ") | ||
s = replace(s, "return " => "") | ||
end | ||
return Base.Meta.parse(s) | ||
code_lowered = Base.Meta.lower(Main, expr) | ||
code_info = code_lowered.args[1] | ||
code = code_info.code # vector | ||
s = build_expr(code) | ||
if startswith(s, "return ") | ||
s = replace(s, "return " => "") | ||
end | ||
return Base.Meta.parse(s) | ||
end |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.