Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added support for complex differentiable functions #37

Open
wants to merge 4 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 30 additions & 4 deletions src/api.jl
Original file line number Diff line number Diff line change
@@ -1,9 +1,12 @@

const DEFINED_DIFFRULES = Dict{Tuple{Union{Expr,Symbol},Symbol,Int},Any}()
const DEFINED_COMPLEX_DIFFRULES = Dict{Tuple{Union{Expr,Symbol},Symbol,Int},Any}()

"""
@define_diffrule M.f(x) = :(df_dx(\$x))
@define_diffrule M.f(x, y) = :(df_dx(\$x, \$y)), :(df_dy(\$x, \$y))
@define_complex_diffrule M.f(x) = :(df_dx(\$x))
@define_complex_diffrule M.f(x, y) = :(df_dx(\$x, \$y)), :(df_dy(\$x, \$y))

Define a new differentiation rule for the function `M.f` and the given arguments, which should
Expand All @@ -16,14 +19,18 @@ interpolated wherever they are used on the RHS.

Note that differentiation rules are purely symbolic, so no type annotations should be used.

The complex version @define_complex_diffrule should be used if M.f is complex differentiable.
If not, @define_diffrule should be used instead.

Examples:

@define_diffrule Base.cos(x) = :(-sin(\$x))
@define_diffrule Base.:/(x, y) = :(inv(\$y)), :(-\$x / (\$y^2))
@define_diffrule Base.polygamma(m, x) = :NaN, :(polygamma(\$m + 1, \$x))
@define_complex_diffrule Base.cos(x) = :(-sin(\$x))
@define_complex_diffrule Base.:/(x, y) = :(inv(\$y)), :(-\$x / (\$y^2))
@define_diffrule Base.polygamma(m, x) = :NaN, :(polygamma(\$m + 1, \$x))

"""
macro define_diffrule(def)

function _getkeyrule(def)
@assert isa(def, Expr) && def.head == :(=) "Diff rule expression does not have a left and right side"
lhs = def.args[1]
rhs = def.args[2]
Expand All @@ -35,6 +42,20 @@ macro define_diffrule(def)
args = lhs.args[2:end]
rule = Expr(:->, Expr(:tuple, args...), rhs)
key = Expr(:tuple, Expr(:quote, M), Expr(:quote, f), length(args))
return key,rule
end

macro define_complex_diffrule(def)
key,rule = _getkeyrule(def)
return esc(quote
$DiffRules.DEFINED_DIFFRULES[$key] = $rule
$DiffRules.DEFINED_COMPLEX_DIFFRULES[$key] = $rule
$key
end)
end

macro define_diffrule(def)
key,rule = _getkeyrule(def)
return esc(quote
$DiffRules.DEFINED_DIFFRULES[$key] = $rule
$key
Expand All @@ -43,6 +64,7 @@ end

"""
diffrule(M::Union{Expr,Symbol}, f::Symbol, args...)
complex_diffrule(M::Union{Expr,Symbol}, f::Symbol, args...)

Return the derivative expression for `M.f` at the given argument(s), with the argument(s)
interpolated into the returned expression.
Expand All @@ -65,9 +87,11 @@ Examples:
(:(c * (x + 2) ^ (c - 1)), :((x + 2) ^ c * log(x + 2)))
"""
diffrule(M::Union{Expr,Symbol}, f::Symbol, args...) = DEFINED_DIFFRULES[M,f,length(args)](args...)
complex_diffrule(M::Union{Expr,Symbol}, f::Symbol, args...) = DEFINED_COMPLEX_DIFFRULES[M,f,length(args)](args...)

"""
hasdiffrule(M::Union{Expr,Symbol}, f::Symbol, arity::Int)
hascomplex_diffrule(M::Union{Expr,Symbol}, f::Symbol, arity::Int)

Return `true` if a differentiation rule is defined for `M.f` and `arity`, or return `false`
otherwise.
Expand All @@ -92,6 +116,7 @@ Examples:
false
"""
hasdiffrule(M::Union{Expr,Symbol}, f::Symbol, arity::Int) = haskey(DEFINED_DIFFRULES, (M, f, arity))
hascomplex_diffrule(M::Union{Expr,Symbol}, f::Symbol, arity::Int) = haskey(DEFINED_COMPLEX_DIFFRULES, (M, f, arity))

"""
diffrules()
Expand All @@ -109,6 +134,7 @@ Examples:

"""
diffrules() = keys(DEFINED_DIFFRULES)
complex_diffrules() = keys(DEFINED_COMPLEX_DIFFRULES)

# For v0.6 and v0.7 compatibility, need to support having the diff rule function enter as a
# `Expr(:quote...)` and a `QuoteNode`. When v0.6 support is dropped, the function will
Expand Down
173 changes: 90 additions & 83 deletions src/rules.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,72 +5,79 @@
# unary #
#-------#

@define_diffrule Base.:+(x) = :( 1 )
@define_diffrule Base.:-(x) = :( -1 )
@define_diffrule Base.sqrt(x) = :( inv(2 * sqrt($x)) )
@define_diffrule Base.cbrt(x) = :( inv(3 * cbrt($x)^2) )
@define_diffrule Base.abs2(x) = :( $x + $x )
@define_diffrule Base.inv(x) = :( -abs2(inv($x)) )
@define_diffrule Base.log(x) = :( inv($x) )
@define_diffrule Base.log10(x) = :( inv($x) / log(10) )
@define_diffrule Base.log2(x) = :( inv($x) / log(2) )
@define_diffrule Base.log1p(x) = :( inv($x + 1) )
@define_diffrule Base.exp(x) = :( exp($x) )
@define_diffrule Base.exp2(x) = :( exp2($x) * log(2) )
@define_diffrule Base.exp10(x) = :( exp10($x) * log(10) )
@define_diffrule Base.expm1(x) = :( exp($x) )
@define_diffrule Base.sin(x) = :( cos($x) )
@define_diffrule Base.cos(x) = :( -sin($x) )
@define_diffrule Base.tan(x) = :( 1 + tan($x)^2 )
@define_diffrule Base.sec(x) = :( sec($x) * tan($x) )
@define_diffrule Base.csc(x) = :( -csc($x) * cot($x) )
@define_diffrule Base.cot(x) = :( -(1 + cot($x)^2) )
@define_diffrule Base.sind(x) = :( (π / 180) * cosd($x) )
@define_diffrule Base.cosd(x) = :( -(π / 180) * sind($x) )
@define_diffrule Base.tand(x) = :( (π / 180) * (1 + tand($x)^2) )
@define_diffrule Base.secd(x) = :( (π / 180) * secd($x) * tand($x) )
@define_diffrule Base.cscd(x) = :( -(π / 180) * cscd($x) * cotd($x) )
@define_diffrule Base.cotd(x) = :( -(π / 180) * (1 + cotd($x)^2) )
@define_diffrule Base.sinpi(x) = :( π * cospi($x) )
@define_diffrule Base.cospi(x) = :( -π * sinpi($x) )
@define_diffrule Base.asin(x) = :( inv(sqrt(1 - $x^2)) )
@define_diffrule Base.acos(x) = :( -inv(sqrt(1 - $x^2)) )
@define_diffrule Base.atan(x) = :( inv(1 + $x^2) )
@define_diffrule Base.asec(x) = :( inv(abs($x) * sqrt($x^2 - 1)) )
@define_diffrule Base.acsc(x) = :( -inv(abs($x) * sqrt($x^2 - 1)) )
@define_diffrule Base.acot(x) = :( -inv(1 + $x^2) )
@define_diffrule Base.asind(x) = :( 180 / π / sqrt(1 - $x^2) )
@define_diffrule Base.acosd(x) = :( -180 / π / sqrt(1 - $x^2) )
@define_diffrule Base.atand(x) = :( 180 / π / (1 + $x^2) )
@define_diffrule Base.asecd(x) = :( 180 / π / abs($x) / sqrt($x^2 - 1) )
@define_diffrule Base.acscd(x) = :( -180 / π / abs($x) / sqrt($x^2 - 1) )
@define_diffrule Base.acotd(x) = :( -180 / π / (1 + $x^2) )
@define_diffrule Base.sinh(x) = :( cosh($x) )
@define_diffrule Base.cosh(x) = :( sinh($x) )
@define_diffrule Base.tanh(x) = :( 1 - tanh($x)^2 )
@define_diffrule Base.sech(x) = :( -tanh($x) * sech($x) )
@define_diffrule Base.csch(x) = :( -coth($x) * csch($x) )
@define_diffrule Base.coth(x) = :( -(csch($x)^2) )
@define_diffrule Base.asinh(x) = :( inv(sqrt($x^2 + 1)) )
@define_diffrule Base.acosh(x) = :( inv(sqrt($x^2 - 1)) )
@define_diffrule Base.atanh(x) = :( inv(1 - $x^2) )
@define_diffrule Base.asech(x) = :( -inv($x * sqrt(1 - $x^2)) )
@define_diffrule Base.acsch(x) = :( -inv(abs($x) * sqrt(1 + $x^2)) )
@define_diffrule Base.acoth(x) = :( inv(1 - $x^2) )
@define_diffrule Base.deg2rad(x) = :( π / 180 )
@define_diffrule Base.rad2deg(x) = :( 180 / π )
@define_complex_diffrule Base.:+(x) = :( 1 )
@define_complex_diffrule Base.:-(x) = :( -1 )
@define_complex_diffrule Base.sqrt(x) = :( inv(2 * sqrt($x)) )
@define_diffrule Base.cbrt(x) = :( inv(3 * cbrt($x)^2) )
@define_diffrule Base.abs2(x) = :( $x + $x )
@define_complex_diffrule Base.inv(x) = :( -(inv($x^2)) )
@define_diffrule Base.inv(x) = :( -abs2(inv($x)) )
@define_complex_diffrule Base.log(x) = :( inv($x) )
@define_complex_diffrule Base.log10(x) = :( inv($x) / log(10) )
@define_complex_diffrule Base.log2(x) = :( inv($x) / log(2) )
@define_complex_diffrule Base.log1p(x) = :( inv($x + 1) )
@define_complex_diffrule Base.exp(x) = :( exp($x) )
@define_complex_diffrule Base.exp2(x) = :( exp2($x) * log(2) )
@define_complex_diffrule Base.exp10(x) = :( exp10($x) * log(10) )
@define_complex_diffrule Base.expm1(x) = :( exp($x) )
@define_complex_diffrule Base.sin(x) = :( cos($x) )
@define_complex_diffrule Base.cos(x) = :( -sin($x) )
@define_complex_diffrule Base.tan(x) = :( 1 + tan($x)^2 )
@define_complex_diffrule Base.sec(x) = :( sec($x) * tan($x) )
@define_complex_diffrule Base.csc(x) = :( -csc($x) * cot($x) )
@define_complex_diffrule Base.cot(x) = :( -(1 + cot($x)^2) )
@define_diffrule Base.sind(x) = :( (π / 180) * cosd($x) )
@define_diffrule Base.cosd(x) = :( -(π / 180) * sind($x) )
@define_diffrule Base.tand(x) = :( (π / 180) * (1 + tand($x)^2) )
@define_diffrule Base.secd(x) = :( (π / 180) * secd($x) * tand($x) )
@define_diffrule Base.cscd(x) = :( -(π / 180) * cscd($x) * cotd($x) )
@define_diffrule Base.cotd(x) = :( -(π / 180) * (1 + cotd($x)^2) )
@define_complex_diffrule Base.sinpi(x) = :( π * cospi($x) )
@define_complex_diffrule Base.cospi(x) = :( -π * sinpi($x) )
@define_complex_diffrule Base.asin(x) = :( inv(sqrt(1 - $x^2)) )
@define_complex_diffrule Base.acos(x) = :( -inv(sqrt(1 - $x^2)) )
@define_complex_diffrule Base.atan(x) = :( inv(1 + $x^2) )
@define_complex_diffrule Base.asec(x) = :( inv($x^2 * sqrt(1 - inv($x^2))) )
@define_diffrule Base.asec(x) = :( inv(abs($x) * sqrt($x^2 - 1)) )
@define_complex_diffrule Base.acsc(x) = :( -inv($x^2*sqrt(1 - inv($x^2))) )
@define_diffrule Base.acsc(x) = :( -inv(abs($x) * sqrt($x^2 - 1)) )
@define_diffrule Base.acot(x) = :( -inv(1 + $x^2) )
@define_diffrule Base.asind(x) = :( 180 / π / sqrt(1 - $x^2) )
@define_diffrule Base.acosd(x) = :( -180 / π / sqrt(1 - $x^2) )
@define_diffrule Base.atand(x) = :( 180 / π / (1 + $x^2) )
@define_diffrule Base.asecd(x) = :( 180 / π / $x^2 / sqrt(1 - inv($x^2)))
@define_diffrule Base.asecd(x) = :( 180 / π / abs($x) / sqrt($x^2 - 1) )
@define_diffrule Base.acscd(x) = :( -180 / π / $x^2 / sqrt(1 - inv($x^2)))
@define_diffrule Base.acscd(x) = :( -180 / π / abs($x) / sqrt($x^2 - 1) )
@define_diffrule Base.acotd(x) = :( -180 / π / (1 + $x^2) )
@define_complex_diffrule Base.sinh(x) = :( cosh($x) )
@define_complex_diffrule Base.cosh(x) = :( sinh($x) )
@define_complex_diffrule Base.tanh(x) = :( 1 - tanh($x)^2 )
@define_complex_diffrule Base.sech(x) = :( -tanh($x) * sech($x) )
@define_complex_diffrule Base.csch(x) = :( -coth($x) * csch($x) )
@define_complex_diffrule Base.coth(x) = :( -(csch($x)^2) )
@define_complex_diffrule Base.asinh(x) = :( inv(sqrt($x^2 + 1)) )
@define_complex_diffrule Base.acosh(x) = :( inv(sqrt($x - 1)*sqrt($x+1)) )
@define_diffrule Base.acosh(x) = :( inv(sqrt($x^2 - 1)) )
@define_complex_diffrule Base.atanh(x) = :( inv(1 - $x^2) )
@define_complex_diffrule Base.asech(x) = :( -inv($x * sqrt(1 - $x^2)) )
@define_complex_diffrule Base.acsch(x) = :( -inv(sqrt($x^4 + $x^2)) )
@define_diffrule Base.acsch(x) = :( -inv(abs($x) * sqrt(1 + $x^2)) )
@define_complex_diffrule Base.acoth(x) = :( inv(1 - $x^2) )
@define_diffrule Base.deg2rad(x) = :( π / 180 )
@define_diffrule Base.rad2deg(x) = :( 180 / π )
if VERSION < v"0.7-"
@define_diffrule Base.gamma(x) = :( digamma($x) * gamma($x) )
@define_diffrule Base.lgamma(x) = :( digamma($x) )
@define_diffrule Base.Math.JuliaLibm.log1p(x) = :( inv($x + 1) )
@define_complex_diffrule Base.gamma(x) = :( digamma($x) * gamma($x) )
@define_complex_diffrule Base.lgamma(x) = :( digamma($x) )
@define_diffrule Base.Math.JuliaLibm.log1p(x) = :( inv($x + 1) )
else
@define_diffrule SpecialFunctions.gamma(x) =
@define_complex_diffrule SpecialFunctions.gamma(x) =
:( SpecialFunctions.digamma($x) * SpecialFunctions.gamma($x) )
@define_diffrule SpecialFunctions.lgamma(x) =
@define_complex_diffrule SpecialFunctions.lgamma(x) =
:( SpecialFunctions.digamma($x) )
end
@define_diffrule Base.transpose(x) = :( 1 )
@define_diffrule Base.abs(x) = :( DiffRules._abs_deriv($x) )
@define_complex_diffrule Base.transpose(x) = :( 1 )
@define_diffrule Base.abs(x) = :( DiffRules._abs_deriv($x) )

# We provide this hook for special number types like `Interval`
# that need their own special definition of `abs`.
Expand All @@ -79,12 +86,12 @@ _abs_deriv(x) = signbit(x) ? -one(x) : one(x)
# binary #
#--------#

@define_diffrule Base.:+(x, y) = :( 1 ), :( 1 )
@define_diffrule Base.:-(x, y) = :( 1 ), :( -1 )
@define_diffrule Base.:*(x, y) = :( $y ), :( $x )
@define_diffrule Base.:/(x, y) = :( inv($y) ), :( -($x / $y / $y) )
@define_diffrule Base.:\(x, y) = :( -($y / $x / $x) ), :( inv($x) )
@define_diffrule Base.:^(x, y) = :( $y * ($x^($y - 1)) ), :( ($x^$y) * log($x) )
@define_complex_diffrule Base.:+(x, y) = :( 1 ), :( 1 )
@define_complex_diffrule Base.:-(x, y) = :( 1 ), :( -1 )
@define_complex_diffrule Base.:*(x, y) = :( $y ), :( $x )
@define_complex_diffrule Base.:/(x, y) = :( inv($y) ), :( -($x / $y / $y) )
@define_complex_diffrule Base.:\(x, y) = :( -($y / $x / $x) ), :( inv($x) )
@define_diffrule Base.:^(x, y) = :( $y * ($x^($y - 1)) ), :( ($x^$y) * log($x) )

if VERSION < v"0.7-"
@define_diffrule Base.atan2(x, y) = :( $y / ($x^2 + $y^2) ), :( -$x / ($x^2 + $y^2) )
Expand All @@ -105,38 +112,38 @@ end
# unary #
#-------#

@define_diffrule SpecialFunctions.erf(x) = :( (2 / sqrt(π)) * exp(-$x * $x) )
@define_complex_diffrule SpecialFunctions.erf(x) = :( (2 / sqrt(π)) * exp(-$x * $x) )
@define_diffrule SpecialFunctions.erfinv(x) =
:( (sqrt(π) / 2) * exp(SpecialFunctions.erfinv($x)^2) )
@define_diffrule SpecialFunctions.erfc(x) = :( -(2 / sqrt(π)) * exp(-$x * $x) )
@define_diffrule SpecialFunctions.erfcinv(x) =
@define_complex_diffrule SpecialFunctions.erfc(x) = :( -(2 / sqrt(π)) * exp(-$x * $x) )
@define_diffrule SpecialFunctions.erfcinv(x) =
:( -(sqrt(π) / 2) * exp(SpecialFunctions.erfcinv($x)^2) )
@define_diffrule SpecialFunctions.erfi(x) = :( (2 / sqrt(π)) * exp($x * $x) )
@define_diffrule SpecialFunctions.erfcx(x) =
@define_complex_diffrule SpecialFunctions.erfi(x) = :( (2 / sqrt(π)) * exp($x * $x) )
@define_complex_diffrule SpecialFunctions.erfcx(x) =
:( (2 * $x * SpecialFunctions.erfcx($x)) - (2 / sqrt(π)) )
@define_diffrule SpecialFunctions.dawson(x) =
@define_complex_diffrule SpecialFunctions.dawson(x) =
:( 1 - (2 * $x * SpecialFunctions.dawson($x)) )
@define_diffrule SpecialFunctions.digamma(x) =
@define_complex_diffrule SpecialFunctions.digamma(x) =
:( SpecialFunctions.trigamma($x) )
@define_diffrule SpecialFunctions.invdigamma(x) =
:( inv(SpecialFunctions.trigamma(SpecialFunctions.invdigamma($x))) )
@define_diffrule SpecialFunctions.trigamma(x) =
@define_complex_diffrule SpecialFunctions.trigamma(x) =
:( SpecialFunctions.polygamma(2, $x) )
@define_diffrule SpecialFunctions.airyai(x) =
@define_complex_diffrule SpecialFunctions.airyai(x) =
:( SpecialFunctions.airyaiprime($x) )
@define_diffrule SpecialFunctions.airyaiprime(x) =
@define_complex_diffrule SpecialFunctions.airyaiprime(x) =
:( $x * SpecialFunctions.airyai($x) )
@define_diffrule SpecialFunctions.airybi(x) =
@define_complex_diffrule SpecialFunctions.airybi(x) =
:( SpecialFunctions.airybiprime($x) )
@define_diffrule SpecialFunctions.airybiprime(x) =
@define_complex_diffrule SpecialFunctions.airybiprime(x) =
:( $x * SpecialFunctions.airybi($x) )
@define_diffrule SpecialFunctions.besselj0(x) =
@define_complex_diffrule SpecialFunctions.besselj0(x) =
:( -SpecialFunctions.besselj1($x) )
@define_diffrule SpecialFunctions.besselj1(x) =
@define_complex_diffrule SpecialFunctions.besselj1(x) =
:( (SpecialFunctions.besselj0($x) - SpecialFunctions.besselj(2, $x)) / 2 )
@define_diffrule SpecialFunctions.bessely0(x) =
@define_complex_diffrule SpecialFunctions.bessely0(x) =
:( -SpecialFunctions.bessely1($x) )
@define_diffrule SpecialFunctions.bessely1(x) =
@define_complex_diffrule SpecialFunctions.bessely1(x) =
:( (SpecialFunctions.bessely0($x) - SpecialFunctions.bessely(2, $x)) / 2 )

# TODO:
Expand Down
Loading