Skip to content

Commit

Permalink
Add new rewrite that does not make strong assumptions on result type (#…
Browse files Browse the repository at this point in the history
  • Loading branch information
odow authored Nov 21, 2022
1 parent 5c4d464 commit c715a0c
Show file tree
Hide file tree
Showing 5 changed files with 612 additions and 31 deletions.
1 change: 1 addition & 0 deletions src/MutableArithmetics.jl
Original file line number Diff line number Diff line change
Expand Up @@ -176,6 +176,7 @@ function isequal_canonical(x::_SparseMat, y::_SparseMat)
end

include("rewrite.jl")
include("rewrite_generic.jl")
include("dispatch.jl")

# Test that can be used to test an implementation of the interface
Expand Down
90 changes: 59 additions & 31 deletions src/rewrite.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,28 +5,23 @@
# one at http://mozilla.org/MPL/2.0/.

"""
@rewrite(expr)
@rewrite(expr; move_factors_into_sums = false)
Return the value of `expr` exploiting the mutability of the temporary
Return the value of `expr`, exploiting the mutability of the temporary
expressions created for the computation of the result.
## Examples
If you have an `Expr` as input, use [`rewrite_and_return`](@ref) instead.
The expression
```julia
MA.@rewrite(x + y * z + u * v * w)
```
is rewritten into
```julia
MA.add_mul!!(
MA.add_mul!!(
MA.copy_if_mutable(x),
y, z),
u, v, w)
```
See [`rewrite`](@ref) for an explanation of the keyword argument.
"""
macro rewrite(expr)
return rewrite_and_return(expr)
macro rewrite(args...)
@assert 1 <= length(args) <= 2
if length(args) == 1
return rewrite_and_return(args[1]; move_factors_into_sums = true)
end
@assert Meta.isexpr(args[2], :(=), 2) &&
args[2].args[1] == :move_factors_into_sums
return rewrite_and_return(args[1]; move_factors_into_sums = args[2].args[2])
end

struct Zero end
Expand Down Expand Up @@ -268,31 +263,64 @@ function _is_decomposable_with_factors(ex)
end

"""
rewrite(x)
rewrite(expr; move_factors_into_sums::Bool = true) -> Tuple{Symbol,Expr}
Rewrites the expression `expr` to use mutable arithmetics.
Returns `(variable, code)` comprised of a `gensym`'d variable equivalent to
`expr` and the code necessary to create the variable.
## `move_factors_into_sums`
If `move_factors_into_sums = true`, some terms are rewritten based on the
assumption that summations produce a linear function.
For example, if `move_factors_into_sums = true`, then
`y * sum(x[i] for i in 1:2)` is rewritten to:
```julia
variable = MA.Zero()
for i in 1:2
variable = MA.operate!!(MA.add_mul, result, y, x[i])
end
```
If `move_factors_into_sums = false`, it is rewritten to:
```julia
term = MA.Zero()
for i in 1:2
term = MA.operate!!(MA.add_mul, term, x[i])
end
variable = MA.operate!!(*, y, term)
```
Rewrite the expression `x` as specified in [`@rewrite`](@ref).
Returns a variable name as `Symbol` and the rewritten expression assigning the
value of the expression `x` to the variable.
The latter can produce an additional allocation if there is an efficient
fallback for `add_mul` and not for `*(y, term)`.
"""
function rewrite(x)
function rewrite(x; kwargs...)
variable = gensym()
code = rewrite_and_return(x)
code = rewrite_and_return(x; kwargs...)
return variable, :($variable = $code)
end

"""
rewrite_and_return(x)
rewrite_and_return(expr; move_factors_into_sums::Bool = true) -> Expr
Rewrite the expression `expr` using mutable arithmetics and return an expression
in which the last statement is equivalent to `expr`.
Rewrite the expression `x` as specified in [`@rewrite`](@ref).
Return the rewritten expression returning the result.
See [`rewrite`](@ref) for an explanation of the keyword argument.
"""
function rewrite_and_return(x)
output_variable, code = _rewrite(false, false, x, nothing, [], [])
# We need to use `let` because `rewrite(:(sum(i for i in 1:2))`
function rewrite_and_return(expr; move_factors_into_sums::Bool = true)
if move_factors_into_sums
root, stack = _rewrite(false, false, expr, nothing, [], [])
else
stack = quote end
root, _ = _rewrite_generic(stack, expr)
end
return quote
let
$code
$output_variable
$stack
$root
end
end
end
Expand Down
Loading

0 comments on commit c715a0c

Please sign in to comment.