From 0cdedc7dda874a831785ae9f5440dcbb73fb4043 Mon Sep 17 00:00:00 2001 From: odow Date: Mon, 14 Nov 2022 13:39:51 +1300 Subject: [PATCH] Place new rewrite behind opt-in kwarg --- docs/src/index.md | 4 - src/MutableArithmetics.jl | 2 +- src/rewrite.jl | 54 ++++--- src/{new_rewrite.jl => rewrite_generic.jl} | 148 ++++++++------------ test/{new_rewrite.jl => rewrite_generic.jl} | 58 +++++--- test/runtests.jl | 2 +- 6 files changed, 131 insertions(+), 137 deletions(-) rename src/{new_rewrite.jl => rewrite_generic.jl} (68%) rename test/{new_rewrite.jl => rewrite_generic.jl} (75%) diff --git a/docs/src/index.md b/docs/src/index.md index 028d6655..1efeb0e4 100644 --- a/docs/src/index.md +++ b/docs/src/index.md @@ -13,7 +13,3 @@ end ```@autodocs Modules = [MutableArithmetics] ``` - -```@autodocs -Modules = [MutableArithmetics.MutableArithmetics2] -``` diff --git a/src/MutableArithmetics.jl b/src/MutableArithmetics.jl index 9c30017e..652f93ca 100644 --- a/src/MutableArithmetics.jl +++ b/src/MutableArithmetics.jl @@ -176,8 +176,8 @@ function isequal_canonical(x::_SparseMat, y::_SparseMat) end include("rewrite.jl") +include("rewrite_generic.jl") include("dispatch.jl") -include("new_rewrite.jl") # Test that can be used to test an implementation of the interface include("Test/Test.jl") diff --git a/src/rewrite.jl b/src/rewrite.jl index 62e355a3..7dec90ca 100644 --- a/src/rewrite.jl +++ b/src/rewrite.jl @@ -5,9 +5,9 @@ # one at http://mozilla.org/MPL/2.0/. """ - @rewrite(expr) + @rewrite(expr; assume_sums_are_linear = false) -Return the value of `expr` exploiting the mutability of the temporary +Return the value of `expr`, exploiting the mutability of the temporary expressions created for the computation of the result. ## Examples @@ -21,12 +21,23 @@ is rewritten into MA.add_mul!!( MA.add_mul!!( MA.copy_if_mutable(x), - y, z), - u, v, w) + y, + z, + ), + u, + v, + w, +) ``` """ -macro rewrite(expr) - return rewrite_and_return(expr) +macro rewrite(args...) + @assert 1 <= length(args) <= 2 + if length(args) == 1 + return rewrite_and_return(args[1]; assume_sums_are_linear = true) + end + @assert Meta.isexpr(args[2], :(=), 2) && + args[2].args[1] == :assume_sums_are_linear + return rewrite_and_return(args[1]; assume_sums_are_linear = args[2].args[2]) end struct Zero end @@ -268,31 +279,36 @@ function _is_decomposable_with_factors(ex) end """ - rewrite(x) + rewrite(expr; assume_sums_are_linear::Bool = true) -> Tuple{Symbol,Expr} + +Rewrites the expression `expr` to use mutable arithmetics. -Rewrite the expression `x` as specified in [`@rewrite`](@ref). -Returns a variable name as `Symbol` and the rewritten expression assigning the -value of the expression `x` to the variable. +Returns `(variable, code)` comprised of a `gensym`'d variable equivalent to +`expr` and the code necessary to create the variable. """ -function rewrite(x) +function rewrite(x; kwargs...) variable = gensym() - code = rewrite_and_return(x) + code = rewrite_and_return(x; kwargs...) return variable, :($variable = $code) end """ - rewrite_and_return(x) + rewrite_and_return(expr; assume_sums_are_linear::Bool = true) -> Expr -Rewrite the expression `x` as specified in [`@rewrite`](@ref). +Rewrite the expression `expr` as specified in [`@rewrite`](@ref). Return the rewritten expression returning the result. """ -function rewrite_and_return(x) - output_variable, code = _rewrite(false, false, x, nothing, [], []) - # We need to use `let` because `rewrite(:(sum(i for i in 1:2))` +function rewrite_and_return(expr; assume_sums_are_linear::Bool = true) + if assume_sums_are_linear + root, stack = _rewrite(false, false, expr, nothing, [], []) + else + stack = quote end + root, _ = _rewrite_generic(stack, expr) + end return quote let - $code - $output_variable + $stack + $root end end end diff --git a/src/new_rewrite.jl b/src/rewrite_generic.jl similarity index 68% rename from src/new_rewrite.jl rename to src/rewrite_generic.jl index 2e0a18c0..672f3ee3 100644 --- a/src/new_rewrite.jl +++ b/src/rewrite_generic.jl @@ -4,58 +4,12 @@ # v.2.0. If a copy of the MPL was not distributed with this file, You can obtain # one at http://mozilla.org/MPL/2.0/. -module MutableArithmetics2 - -import ..MutableArithmetics - -const MA = MutableArithmetics - # We need these two methods because we're changing how * is re-written. -MA.operate!(::typeof(*), x::AbstractArray{T}, y::T) where {T} = (x .*= y) -MA.operate!(::typeof(*), x::AbstractArray, y) = (x .= MA.operate(*, x, y)) - -""" - @rewrite(expr) - -Rewrites the expression `expr` to use mutable arithmetics. +operate!(::typeof(*), x::AbstractArray{T}, y::T) where {T} = (x .*= y) +operate!(::typeof(*), x::AbstractArray, y) = (x .= operate(*, x, y)) -For a non-macro version, see [`rewrite_and_return`](@ref). """ -macro rewrite(expr) - return rewrite_and_return(expr) -end - -""" - rewrite_and_return(expr) -> Expr - -Rewrites the expression `expr` to use mutable arithmetics. -""" -function rewrite_and_return(expr) - stack = quote end - root, _ = _rewrite(stack, expr) - return quote - let - $stack - $root - end - end -end - -""" - rewrite(expr) -> Tuple{Symbol,Expr} - -Rewrites the expression `expr` to use mutable arithmetics. Returns -`(variable, code)` comprised of a `gensym`'d variable equivalent to `expr` and -the code necessary to create the variable. -""" -function rewrite(expr) - variable = gensym() - code = rewrite_and_return(expr) - return variable, :($variable = $code) -end - -""" - _rewrite(stack::Expr, expr::T)::Tuple{Any,Bool} + _rewrite_generic(stack::Expr, expr::T)::Tuple{Any,Bool} This method is the heart of the rewrite logic. It converts `expr` into a mutable equivalent, places any intermediate calculations onto `stack`, and returns a @@ -63,17 +17,18 @@ tuple containing the return value---which is either `expr` or a `gensym`ed variable equivalent to `expr`---and a boolean flag that indicates whether the return value can be mutated by future callers. """ -function _rewrite end +function _rewrite_generic end + """ - _rewrite(::Expr, x) + _rewrite_generic(::Expr, x) A generic fallback. Given a type `x` we return it without mutation. In addition, this type should not be mutated by future callers. """ -_rewrite(::Expr, x) = esc(x), false +_rewrite_generic(::Expr, x) = esc(x), false """ - _rewrite(::Expr, x::Number) + _rewrite_generic(::Expr, x::Number) If `x` is a `Number` at macro expansion time, it _must_ be a constant literal. We return `x` without mutation, but we return `true` because other callers may @@ -82,15 +37,15 @@ in `copy_if_mutable(x)` before using it as the first argument to `operate!!`. This most commonly happens in situations like `x^2`. """ -_rewrite(::Expr, x::Number) = x, true +_rewrite_generic(::Expr, x::Number) = x, true """ - _rewrite(stack::Expr, expr::Expr) + _rewrite_generic(stack::Expr, expr::Expr) This method is the heart of the rewrite logic. It converts `expr` into a mutable equivalent. """ -function _rewrite(stack::Expr, expr::Expr) +function _rewrite_generic(stack::Expr, expr::Expr) if !Meta.isexpr(expr, :call) # In situations like `x[i]`, we do not attempt to rewrite. Return `expr` # and don't let future callers mutate. @@ -105,10 +60,10 @@ function _rewrite(stack::Expr, expr::Expr) # This is a generator expression like `sum(i for i in args)`. Generators # come in two forms: `sum(i for i=I, j=J)` or `sum(i for i=I for j=J)`. # The latter is a `:flatten` expression and needs additional handling, - # but we delay this complexity for _rewrite_generator. + # but we delay this complexity for _rewrite_generic_generator. if expr.args[1] in (:sum, :Σ, :∑) # Summations use :+ as the reduction operator. - return _rewrite_generator(stack, :+, expr.args[2]) + return _rewrite_generic_generator(stack, :+, expr.args[2]) end # We don't know what this is. Return the expression and don't let # future callers mutate. @@ -121,32 +76,32 @@ function _rewrite(stack::Expr, expr::Expr) # +(args...) => add_mul(add_mul(arg1, arg2), arg3) @assert length(expr.args) > 1 if length(expr.args) == 2 # +(arg) - return _rewrite(stack, expr.args[2]) + return _rewrite_generic(stack, expr.args[2]) end - return _rewrite_to_nested_op(stack, expr, MA.add_mul) + return _rewrite_generic_to_nested_op(stack, expr, add_mul) elseif expr.args[1] == :- # -(args...) => sub_mul(sub_mul(arg1, arg2), arg3) @assert length(expr.args) > 1 if length(expr.args) == 2 # -(arg) - return _rewrite(stack, Expr(:call, :*, -1, expr.args[2])) + return _rewrite_generic(stack, Expr(:call, :*, -1, expr.args[2])) end - return _rewrite_to_nested_op(stack, expr, MA.sub_mul) + return _rewrite_generic_to_nested_op(stack, expr, sub_mul) elseif expr.args[1] == :* # *(args...) => *(*(arg1, arg2), arg3) @assert length(expr.args) > 2 - arg1, is_mutable = _rewrite(stack, expr.args[2]) - arg2, _ = _rewrite(stack, expr.args[3]) + arg1, is_mutable = _rewrite_generic(stack, expr.args[2]) + arg2, _ = _rewrite_generic(stack, expr.args[3]) rhs = if is_mutable - Expr(:call, MA.operate!!, *, arg1, arg2) + Expr(:call, operate!!, *, arg1, arg2) else Expr(:call, *, arg1, arg2) end root = gensym() push!(stack.args, :($root = $rhs)) for i in 4:length(expr.args) - arg, _ = _rewrite(stack, expr.args[i]) + arg, _ = _rewrite_generic(stack, expr.args[i]) rhs = if is_mutable - Expr(:call, MA.operate!!, *, root, arg) + Expr(:call, operate!!, *, root, arg) else Expr(:call, *, root, arg) end @@ -158,21 +113,31 @@ function _rewrite(stack::Expr, expr::Expr) # .+(args...) => add_mul.(add_mul.(arg1, arg2), arg3) @assert length(expr.args) > 1 if length(expr.args) == 2 # +(arg) - return _rewrite(stack, expr.args[2]) + return _rewrite_generic(stack, expr.args[2]) end - return _rewrite_to_nested_op(stack, expr, MA.add_mul; broadcast = true) + return _rewrite_generic_to_nested_op( + stack, + expr, + add_mul; + broadcast = true, + ) elseif expr.args[1] == :.- # .-(args...) => sub_mul.(sub_mul.(arg1, arg2), arg3) @assert length(expr.args) > 1 if length(expr.args) == 2 # .-(arg) - return _rewrite(stack, Expr(:call, :.*, -1, expr.args[2])) + return _rewrite_generic(stack, Expr(:call, :.*, -1, expr.args[2])) end - return _rewrite_to_nested_op(stack, expr, MA.sub_mul; broadcast = true) + return _rewrite_generic_to_nested_op( + stack, + expr, + sub_mul; + broadcast = true, + ) else # Use the non-mutating call. result = Expr(:call, esc(expr.args[1])) for i in 2:length(expr.args) - arg, _ = _rewrite(stack, expr.args[i]) + arg, _ = _rewrite_generic(stack, expr.args[i]) push!(result.args, arg) end root = gensym() @@ -183,20 +148,20 @@ function _rewrite(stack::Expr, expr::Expr) end end -function _rewrite_to_nested_op(stack, expr, op; broadcast::Bool = false) - root, is_mutable = _rewrite(stack, expr.args[2]) +function _rewrite_generic_to_nested_op(stack, expr, op; broadcast::Bool = false) + root, is_mutable = _rewrite_generic(stack, expr.args[2]) if !is_mutable # The first argument isn't mutable, so we need to make a copy. - arg = Expr(:call, MA.copy_if_mutable, root) + arg = Expr(:call, copy_if_mutable, root) root = gensym() push!(stack.args, Expr(:(=), root, arg)) end for i in 3:length(expr.args) - arg, _ = _rewrite(stack, expr.args[i]) + arg, _ = _rewrite_generic(stack, expr.args[i]) rhs = if broadcast - Expr(:call, MA.broadcast!!, op, root, arg) + Expr(:call, broadcast!!, op, root, arg) else - Expr(:call, MA.operate!!, op, root, arg) + Expr(:call, operate!!, op, root, arg) end root = gensym() push!(stack.args, Expr(:(=), root, rhs)) @@ -207,13 +172,18 @@ end _is_call(expr, op) = Meta.isexpr(expr, :call) && expr.args[1] == op """ - _rewrite_generator(stack::Expr, op::Symbol, expr::Expr) + _rewrite_generic_generator(stack::Expr, op::Symbol, expr::Expr) Special handling for generator expressions. `op` is `:+` and `expr` is a `:generator` or `:flatten` expression. """ -function _rewrite_generator(stack::Expr, op::Symbol, expr::Expr, root = nothing) +function _rewrite_generic_generator( + stack::Expr, + op::Symbol, + expr::Expr, + root = nothing, +) @assert op == :+ is_flatten = Meta.isexpr(expr, :flatten) if is_flatten @@ -222,7 +192,7 @@ function _rewrite_generator(stack::Expr, op::Symbol, expr::Expr, root = nothing) # The value we're going to mutate. Start it off at `Zero`. if root === nothing root = gensym() - push!(stack.args, Expr(:(=), root, MA.Zero())) + push!(stack.args, Expr(:(=), root, Zero())) end # We need a new stack to go inside our for-loops since we want to # recursively rewrite the inner part as well. @@ -231,29 +201,29 @@ function _rewrite_generator(stack::Expr, op::Symbol, expr::Expr, root = nothing) # Optimization time! Instead of operate!!(op, root, op(args...)), # rewrite as operate!!(op, root, arg) for arg in args for arg in expr.args[1].args[2:end] - value, _ = _rewrite(new_stack, arg) - rhs = Expr(:call, MA.operate!!, MA.add_mul, root, value) + value, _ = _rewrite_generic(new_stack, arg) + rhs = Expr(:call, operate!!, add_mul, root, value) push!(new_stack.args, :($root = $rhs)) end elseif op == :+ && _is_call(expr.args[1], :*) # Optimization time! Instead of operate!!(+, root, *(args...)), rewrite # this as operate!!(add_mul, root, args...) - rhs = Expr(:call, MA.operate!!, MA.add_mul, root) + rhs = Expr(:call, operate!!, add_mul, root) for arg in expr.args[1].args[2:end] - value, _ = _rewrite(new_stack, arg) + value, _ = _rewrite_generic(new_stack, arg) push!(rhs.args, value) end push!(new_stack.args, :($root = $rhs)) elseif is_flatten # The first argument is itself a generator - _rewrite_generator(new_stack, op, expr.args[1], root) + _rewrite_generic_generator(new_stack, op, expr.args[1], root) else # expr.args[1] is the inner part of the loop. Rewrite it. We don't care # if it is mutable because we need a new value every iteration. - inner, _ = _rewrite(new_stack, expr.args[1]) + inner, _ = _rewrite_generic(new_stack, expr.args[1]) # Now build up the summation or product part of the inner loop. It's # always safe to mutate because we're going to start with `root=Zero()`. - rhs = Expr(:call, MA.operate!!, MA.add_mul, root, inner) + rhs = Expr(:call, operate!!, add_mul, root, inner) push!(new_stack.args, :($root = $rhs)) end # This is a little complicated: walk back out of the generator statements @@ -289,5 +259,3 @@ function _iterable_condition(new_stack, expr) end return body end - -end # module diff --git a/test/new_rewrite.jl b/test/rewrite_generic.jl similarity index 75% rename from test/new_rewrite.jl rename to test/rewrite_generic.jl index ea60ce64..3a076fc6 100644 --- a/test/new_rewrite.jl +++ b/test/rewrite_generic.jl @@ -4,14 +4,13 @@ # v.2.0. If a copy of the MPL was not distributed with this file, You can obtain # one at http://mozilla.org/MPL/2.0/. -module TestMutableArithmetics2 +module TestRewriteGeneric using Test import MutableArithmetics const MA = MutableArithmetics -const MA2 = MA.MutableArithmetics2 function runtests() for name in names(@__MODULE__; all = true) @@ -26,12 +25,17 @@ end macro test_rewrite(expr) return quote - esc(@test MA.isequal_canonical(MA2.@rewrite($expr), $expr)) + esc( + @test MA.isequal_canonical( + MA.@rewrite($expr, assume_sums_are_linear = false), + $expr, + ) + ) end end function test_rewrite() - x, expr = MA2.rewrite(1 + 1) + x, expr = MA.rewrite(1 + 1, assume_sums_are_linear = false) @test x isa Symbol @test Meta.isexpr(expr, :(=), 2) return @@ -52,7 +56,7 @@ end function test_rewrite_not_call() x = [1, 2, 3] for i in 1:3 - @test MA2.@rewrite(x[i]) == i + @test MA.@rewrite(x[i], assume_sums_are_linear = false) == i end return end @@ -66,10 +70,10 @@ function test_rewrite_sum_to_add_mul() @test_rewrite +1 @test_rewrite +(-2) x = [1.2] - @test MA2.@rewrite(+x) == x - @test MA2.@rewrite(x + x) == 2 * x - @test MA2.@rewrite(+(x, x, x)) == 3 * x - @test MA2.@rewrite(+(x, x, x, x)) == 4 * x + @test MA.@rewrite(+x, assume_sums_are_linear = false) == x + @test MA.@rewrite(x + x, assume_sums_are_linear = false) == 2 * x + @test MA.@rewrite(+(x, x, x), assume_sums_are_linear = false) == 3 * x + @test MA.@rewrite(+(x, x, x, x), assume_sums_are_linear = false) == 4 * x return end @@ -78,17 +82,17 @@ function test_rewrite_prod_to_add_mul() @test_rewrite 2 * -2 A = [1.0 2.0; 3.0 4.0] x = [5.0, 6.0] - @test MA2.@rewrite(A * x) == A * x + @test MA.@rewrite(A * x, assume_sums_are_linear = false) == A * x return end function test_rewrite_nonconcrete_vector() x = [5.0, 6.0] y = Vector{Union{Float64,String}}(x) - @test MA2.@rewrite(x' * y) == x' * y - @test MA2.@rewrite(x .+ y) == x .+ y + @test MA.@rewrite(x' * y, assume_sums_are_linear = false) == x' * y + @test MA.@rewrite(x .+ y, assume_sums_are_linear = false) == x .+ y # Reproducing buggy behavior in MA.@rewrite. - @test_broken MA2.@rewrite(x + y) == x + x + @test_broken MA.@rewrite(x + y, assume_sums_are_linear = false) == x + x return end @@ -97,15 +101,16 @@ function test_rewrite_minus_to_add_mul() @test_rewrite -(+2) @test_rewrite -(-2) x = [1.2] - @test MA2.@rewrite(-x) == -x - @test MA2.@rewrite(x - x) == [0.0] - @test MA2.@rewrite(-(x, x, x)) == -1 * x - @test MA2.@rewrite(-(x, x, x, x)) == -2 * x + @test MA.@rewrite(-x, assume_sums_are_linear = false) == -x + @test MA.@rewrite(x - x, assume_sums_are_linear = false) == [0.0] + @test MA.@rewrite(-(x, x, x), assume_sums_are_linear = false) == -1 * x + @test MA.@rewrite(-(x, x, x, x), assume_sums_are_linear = false) == -2 * x return end function test_rewrite_sum() - @test MA2.@rewrite(sum(i for i in 1:0)) == MA.Zero() + @test MA.@rewrite(sum(i for i in 1:0), assume_sums_are_linear = false) == + MA.Zero() @test_rewrite sum(i for i in 1:10) @test_rewrite sum(i + i^2 for i in 1:10) @test_rewrite sum(i * i for i in 1:10) @@ -134,8 +139,14 @@ function test_rewrite_generator() @test_rewrite sum(i + j^2 for i in 1:2 for j in 2:3 for k in 3:4) # Generators with dependent variables # This syntax is unsupported by Julia! - @test MA2.@rewrite(sum(i + j^2 for i in 1:2, j in i:3)) == 34 - @test MA2.@rewrite(sum(i + j^2 for i in 1:2, j in 2:3, k in i:j)) == 68 + @test MA.@rewrite( + sum(i + j^2 for i in 1:2, j in i:3), + assume_sums_are_linear = false, + ) == 34 + @test MA.@rewrite( + sum(i + j^2 for i in 1:2, j in 2:3, k in i:j), + assume_sums_are_linear = false, + ) == 68 # Unnivariate generators with an if statement @test_rewrite sum(i for i in 1:2 if i >= 1) @test_rewrite sum(j^2 for j in 2:3 if j <= 3) @@ -169,7 +180,10 @@ function test_rewrite_linear_algebra() x = [1, 2] A = [1 2; 3 4] y = reshape(x, (1, length(x))) * A * x .- 1 - @test MA2.@rewrite(reshape(x, (1, length(x))) * A * x .- 1) == y + @test MA.@rewrite( + reshape(x, (1, length(x))) * A * x .- 1, + assume_sums_are_linear = false, + ) == y return end @@ -205,4 +219,4 @@ end end # module -TestMutableArithmetics2.runtests() +TestRewriteGeneric.runtests() diff --git a/test/runtests.jl b/test/runtests.jl index 90a953c8..a1c52859 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -26,7 +26,7 @@ end include("matmul.jl") include("dispatch.jl") include("rewrite.jl") -include("new_rewrite.jl") +include("rewrite_generic.jl") # It is easy to introduce macro scoping issues into MutableArithmetics, # particularly ones that rely on `MA` or `MutableArithmetics` being present in