diff --git a/src/Symbolics.jl b/src/Symbolics.jl index 146841893..32f94cf7e 100644 --- a/src/Symbolics.jl +++ b/src/Symbolics.jl @@ -238,4 +238,7 @@ end export inverse, left_inverse, right_inverse, @register_inverse, has_inverse, has_left_inverse, has_right_inverse include("inverse.jl") +export rootfunction, left_continuous_function, right_continuous_function, @register_discontinuity +include("discontinuities.jl") + end # module diff --git a/src/discontinuities.jl b/src/discontinuities.jl new file mode 100644 index 000000000..c5114b57c --- /dev/null +++ b/src/discontinuities.jl @@ -0,0 +1,106 @@ +""" + rootfunction(f) + +Given a function `f` with a discontinuity or discontinuous derivative, return the rootfinding +function of `f`. The rootfinding function `g` takes the same arguments as `f`, and is such +that `f` can be described as a piecewise function based on the sign of `g`, where each piece +is continuous and has a continuous derivative. The pieces are obtained using +`left_continuous_function(f)` and `right_continuous_function(f)`. + +More formally, +```julia +f(args...) = if g(args...) < 0 + left_continuous_function(f)(args...) +else + right_continuous_function(f)(args...) +end +``` + +For example, if `f` is `max(x, y)`, the root function is `(x, y) -> x - y` with +`left_continuous_function` as `(x, y) -> y` and `right_continuous_function` as +`(x, y) -> x`. + +See also: [`left_continuous_function`](@ref), [`right_continuous_function`](@ref). +""" +function rootfunction end + +""" + left_continuous_function(f) + +Given a function `f` with a discontinuity or discontinuous derivative, return a function +taking the same arguments as `f` which is continuous and has a continuous derivative +when `rootfinding_function(f)` is negative. + +See also: [`rootfunction`](@ref). +""" +function left_continuous_function end + +""" + right_continuous_function(f) + +Given a function `f` with a discontinuity or discontinuous derivative, return a function +taking the same arguments as `f` which is continuous and has a continuous derivative +when `rootfinding_function(f)` is positive. + +See also: [`rootfunction`](@ref). +""" +function right_continuous_function end + +""" + @register_discontinuity f(arg1, arg2, ...) root_expr left_expr right_expr + +Utility macro to register functions with discontinuities. The function `f` with +arguments `arg1, arg2, ...` has a `rootfunction` of `root_expr`, a +`left_continuous_function` of `left_expr` and `right_continuous_function` of +`right_expr`. `root_expr`, `left_expr` and `right_expr` are all expressions in terms +of `arg1, arg2, ...`. + +For example, `max(x, y)` can be registered as `@register_discontinuity max(x, y) x - y y x`. + +See also: [`rootfunction`](@ref) +""" +macro register_discontinuity(f, root, left, right) + Meta.isexpr(f, :call) || error("Expected function call as first argument") + args = f.args[2:end] + fn = esc(f.args[1]) + rootname = gensym(:root) + rootfn = :(function $rootname($(args...)) + $root + end) + leftname = gensym(:left) + leftfn = :(function $leftname($(args...)) + $left + end) + rightname = gensym(:right) + rightfn = :(function $rightname($(args...)) + $right + end) + return quote + $rootfn + (::$typeof($rootfunction))(::$typeof($fn)) = $rootname + $leftfn + (::$typeof($left_continuous_function))(::$typeof($fn)) = $leftname + $rightfn + (::$typeof($right_continuous_function))(::$typeof($fn)) = $rightname + end +end + +# a triangle function which is zero when x is a multiple of period +function _triangle(x, period) + x /= 2period + abs(x + 1 // 4 - floor(x + 3 // 4)) - 1 // 2 +end + +@register_discontinuity abs(x) x -x x +# just needs a rootfind to hit the discontinuity +@register_discontinuity mod(x, y) _triangle(x, y) mod(x, y) mod(x, y) +@register_discontinuity rem(x, y) _triangle(x, y) rem(x, y) rem(x, y) +@register_discontinuity div(x, y) _triangle(x, y) div(x, y) div(x, y) +@register_discontinuity max(x, y) x - y y x +@register_discontinuity min(x, y) x - y x y +@register_discontinuity NaNMath.max(x, y) x - y y x +@register_discontinuity NaNMath.min(x, y) x - y x y +@register_discontinuity <(x, y) x - y true false +@register_discontinuity <=(x, y) y - x false true +@register_discontinuity >(x, y) y - x true false +@register_discontinuity >=(x, y) x - y false true diff --git a/src/inverse.jl b/src/inverse.jl index 3fb87a46a..fcc9ede65 100644 --- a/src/inverse.jl +++ b/src/inverse.jl @@ -47,24 +47,26 @@ inverse. """ macro register_inverse(f, g, dir::QuoteNode = :(:both)) dir = dir.value + f = esc(f) + g = esc(g) if dir == :both quote - (::typeof($inverse))(::typeof($f)) = $g - (::typeof($inverse))(::typeof($g)) = $f - (::typeof($left_inverse))(::typeof($f)) = $(inverse)($f) - (::typeof($right_inverse))(::typeof($f)) = $(inverse)($f) - (::typeof($left_inverse))(::typeof($g)) = $(inverse)($g) - (::typeof($right_inverse))(::typeof($g)) = $(inverse)($g) + (::$typeof($inverse))(::$typeof($f)) = $g + (::$typeof($inverse))(::$typeof($g)) = $f + (::$typeof($left_inverse))(::$typeof($f)) = $(inverse)($f) + (::$typeof($right_inverse))(::$typeof($f)) = $(inverse)($f) + (::$typeof($left_inverse))(::$typeof($g)) = $(inverse)($g) + (::$typeof($right_inverse))(::$typeof($g)) = $(inverse)($g) end elseif dir == :left quote - (::typeof($left_inverse))(::typeof($f)) = $g - (::typeof($right_inverse))(::typeof($g)) = $f + (::$typeof($left_inverse))(::$typeof($f)) = $g + (::$typeof($right_inverse))(::$typeof($g)) = $f end elseif dir == :right quote - (::typeof($right_inverse))(::typeof($f)) = $g - (::typeof($left_inverse))(::typeof($g)) = $f + (::$typeof($right_inverse))(::$typeof($f)) = $g + (::$typeof($left_inverse))(::$typeof($g)) = $f end else throw(ArgumentError("The third argument to `@register_inverse` must be `left` or `right`")) diff --git a/test/discontinuities.jl b/test/discontinuities.jl new file mode 100644 index 000000000..868f54f49 --- /dev/null +++ b/test/discontinuities.jl @@ -0,0 +1,30 @@ +using Symbolics, NaNMath, Test + +function discontinuity_eval(fn, args...) + if rootfunction(fn)(args...) < 0 + left_continuous_function(fn)(args...) + else + right_continuous_function(fn)(args...) + end +end + +@testset "abs" begin + for x in -1.0:0.001:1.0 + @test abs(x) ≈ discontinuity_eval(abs, x) + end +end + +@testset "$(nameof(f))" for f in (mod, rem, div) + y = 0.7 + for x in -2y:0.001:2y + @test f(x, y) ≈ discontinuity_eval(f, x, y) + end +end + +@testset "$(nameof(f))" for f in (min, max, NaNMath.min, NaNMath.max, <, <=, >, >=) + for x in 0.0:0.1:1.0 + for y in 0.0:0.1:1.0 + @test f(x, y) ≈ discontinuity_eval(f, x, y) + end + end +end diff --git a/test/runtests.jl b/test/runtests.jl index b42a8cae4..e9c7a6b64 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -62,6 +62,7 @@ if GROUP == "All" || GROUP == "Core" @safetestset "RootFinding solver" begin include("solver.jl") end @safetestset "Function inverses test" begin include("inverse.jl") end @safetestset "Taylor Series Test" begin include("taylor.jl") end + @safetestset "Discontinuity registration test" begin include("discontinuities.jl") end end end