diff --git a/LICENSE b/LICENSE index 41305ea..c91599d 100644 --- a/LICENSE +++ b/LICENSE @@ -18,3 +18,10 @@ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. +============================================================================= + +Some parts of src/decl.jl is from Symbolics.jl and SymPy.jl licensed under the +same license with the copyrights + +Copyright (c) <2013> +Copyright (c) 2021: Shashi Gowda, Yingbo Ma, Chris Rackauckas, Julia Computing. diff --git a/src/SymEngine.jl b/src/SymEngine.jl index e5607a9..c5d6693 100644 --- a/src/SymEngine.jl +++ b/src/SymEngine.jl @@ -23,6 +23,7 @@ const libversion = get_libversion() include("exceptions.jl") include("types.jl") include("ctypes.jl") +include("decl.jl") include("display.jl") include("mathops.jl") include("mathfuns.jl") diff --git a/src/decl.jl b/src/decl.jl new file mode 100644 index 0000000..6622b08 --- /dev/null +++ b/src/decl.jl @@ -0,0 +1,150 @@ +# !!! Note: +# Many thanks to `@matthieubulte` for this contribution to `SymPy`. + +# The map_subscripts function is stolen from Symbolics.jl +const IndexMap = Dict{Char,Char}( + '-' => '₋', + '0' => '₀', + '1' => '₁', + '2' => '₂', + '3' => '₃', + '4' => '₄', + '5' => '₅', + '6' => '₆', + '7' => '₇', + '8' => '₈', + '9' => '₉') + +function map_subscripts(indices) + str = string(indices) + join(IndexMap[c] for c in str) +end + +# Define a type hierarchy to describe a variable declaration. This is mainly for convenient pattern matching later. +abstract type VarDecl end + +struct SymDecl <: VarDecl + sym :: Symbol +end + +struct NamedDecl <: VarDecl + name :: String + rest :: VarDecl +end + +struct FunctionDecl <: VarDecl + rest :: VarDecl +end + +struct TensorDecl <: VarDecl + ranges :: Vector{AbstractRange} + rest :: VarDecl +end + +struct AssumptionsDecl <: VarDecl + assumptions :: Vector{Symbol} + rest :: VarDecl +end + +# Transform a Decl struct in an Expression that calls SymPy to declare the corresponding symbol +function gendecl(x::VarDecl) + asstokw(a) = Expr(:kw, esc(a), true) + val = :($(ctor(x))($(name(x, missing)), $(map(asstokw, assumptions(x))...))) + :($(esc(sym(x))) = $(genreshape(val, x))) +end + +# Transform an expression in a Decl struct +function parsedecl(expr) + # @vars x + if isa(expr, Symbol) + return SymDecl(expr) + + # @vars x::assumptions, where assumption = assumptionkw | (assumptionkw...) + #= no assumptions in SymEngine + elseif isa(expr, Expr) && expr.head == :(::) + symexpr, assumptions = expr.args + assumptions = isa(assumptions, Symbol) ? [assumptions] : assumptions.args + return AssumptionsDecl(assumptions, parsedecl(symexpr)) + =# + + # @vars x=>"name" + elseif isa(expr, Expr) && expr.head == :call && expr.args[1] == :(=>) + length(expr.args) == 3 || parseerror() + isa(expr.args[3], String) || parseerror() + + expr, strname = expr.args[2:end] + return NamedDecl(strname, parsedecl(expr)) + + # @vars x() + elseif isa(expr, Expr) && expr.head == :call && expr.args[1] != :(=>) + length(expr.args) == 1 || parseerror() + return FunctionDecl(parsedecl(expr.args[1])) + + # @vars x[1:5, 3:9] + elseif isa(expr, Expr) && expr.head == :ref + length(expr.args) > 1 || parseerror() + ranges = map(parserange, expr.args[2:end]) + return TensorDecl(ranges, parsedecl(expr.args[1])) + else + parseerror() + end +end + +function parserange(expr) + range = eval(expr) + isa(range, AbstractRange) || parseerror() + range +end + +sym(x::SymDecl) = x.sym +sym(x::NamedDecl) = sym(x.rest) +sym(x::FunctionDecl) = sym(x.rest) +sym(x::TensorDecl) = sym(x.rest) +sym(x::AssumptionsDecl) = sym(x.rest) + +ctor(::SymDecl) = :symbols +ctor(x::NamedDecl) = ctor(x.rest) +ctor(::FunctionDecl) = :SymFunction +ctor(x::TensorDecl) = ctor(x.rest) +ctor(x::AssumptionsDecl) = ctor(x.rest) + +assumptions(::SymDecl) = [] +assumptions(x::NamedDecl) = assumptions(x.rest) +assumptions(x::FunctionDecl) = assumptions(x.rest) +assumptions(x::TensorDecl) = assumptions(x.rest) +assumptions(x::AssumptionsDecl) = x.assumptions + +# Reshape is not used by most nodes, but TensorNodes require the output to be given +# the shape matching the specification. For instance if @vars x[1:3, 2:6], we should +# have size(x) = (3, 5) +genreshape(expr, ::SymDecl) = expr +genreshape(expr, x::NamedDecl) = genreshape(expr, x.rest) +genreshape(expr, x::FunctionDecl) = genreshape(expr, x.rest) +genreshape(expr, x::TensorDecl) = let + shape = tuple(length.(x.ranges)...) + :(reshape(collect($(expr)), $(shape))) +end +genreshape(expr, x::AssumptionsDecl) = genreshape(expr, x.rest) + +# To find out the name, we need to traverse in both directions to make sure that each node can get +# information from parents and children about possible name. +# This is done because the expr tree will always look like NamedDecl -> ... -> TensorDecl -> ... -> SymDecl +# and the TensorDecl node will need to know if it should create names base on a NamedDecl parent or +# based on the SymDecl leaf. +name(x::SymDecl, parentname) = coalesce(parentname, String(x.sym)) +name(x::NamedDecl, parentname) = coalesce(name(x.rest, x.name), x.name) +name(x::FunctionDecl, parentname) = name(x.rest, parentname) +name(x::AssumptionsDecl, parentname) = name(x.rest, parentname) +name(x::TensorDecl, parentname) = let + basename = name(x.rest, parentname) + # we need to double reverse the indices to make sure that we traverse them in the natural order + namestensor = map(Iterators.product(x.ranges...)) do ind + sub = join(map(map_subscripts, ind), "_") + string(basename, sub) + end + join(namestensor[:], ", ") +end + +function parseerror() + error("Incorrect @vars syntax. Try `@vars x=>\"x₀\" y() z[0:4]` for instance.") +end diff --git a/src/mathfuns.jl b/src/mathfuns.jl index 72e30a9..6cf7880 100644 --- a/src/mathfuns.jl +++ b/src/mathfuns.jl @@ -26,38 +26,44 @@ end ## import from base one argument functions ## these are from cwrapper.cpp, one arg func for (meth, libnm, modu) in [ - (:abs,:abs,:Base), - (:sin,:sin,:Base), - (:cos,:cos,:Base), - (:tan,:tan,:Base), - (:csc,:csc,:Base), - (:sec,:sec,:Base), - (:cot,:cot,:Base), - (:asin,:asin,:Base), - (:acos,:acos,:Base), - (:asec,:asec,:Base), - (:acsc,:acsc,:Base), - (:atan,:atan,:Base), - (:acot,:acot,:Base), - (:sinh,:sinh,:Base), - (:cosh,:cosh,:Base), - (:tanh,:tanh,:Base), - (:csch,:csch,:Base), - (:sech,:sech,:Base), - (:coth,:coth,:Base), - (:asinh,:asinh,:Base), - (:acosh,:acosh,:Base), - (:asech,:asech,:Base), - (:acsch,:acsch,:Base), - (:atanh,:atanh,:Base), - (:acoth,:acoth,:Base), - (:gamma,:gamma,:SpecialFunctions), - (:log,:log,:Base), - (:sqrt,:sqrt,:Base), - (:exp,:exp,:Base), - (:eta,:dirichlet_eta,:SpecialFunctions), - (:zeta,:zeta,:SpecialFunctions), - ] + (:abs,:abs,:Base), + (:sin,:sin,:Base), + (:cos,:cos,:Base), + (:tan,:tan,:Base), + (:csc,:csc,:Base), + (:sec,:sec,:Base), + (:cot,:cot,:Base), + (:asin,:asin,:Base), + (:acos,:acos,:Base), + (:asec,:asec,:Base), + (:acsc,:acsc,:Base), + (:atan,:atan,:Base), + (:acot,:acot,:Base), + (:sinh,:sinh,:Base), + (:cosh,:cosh,:Base), + (:tanh,:tanh,:Base), + (:csch,:csch,:Base), + (:sech,:sech,:Base), + (:coth,:coth,:Base), + (:asinh,:asinh,:Base), + (:acosh,:acosh,:Base), + (:asech,:asech,:Base), + (:acsch,:acsch,:Base), + (:atanh,:atanh,:Base), + (:acoth,:acoth,:Base), + (:log,:log,:Base), + (:sqrt,:sqrt,:Base), + (:cbrt,:cbrt,:Base), + (:exp,:exp,:Base), + (:floor,:floor,:Base), + (:ceil, :ceiling,:Base), + (:erf, :erf, :SpecialFunctions), + (:erfc, :erfc, :SpecialFunctions), + (:eta,:dirichlet_eta,:SpecialFunctions), + (:gamma,:gamma,:SpecialFunctions), + (:loggamma,:loggamma,:SpecialFunctions), + (:zeta,:zeta,:SpecialFunctions), +] eval(:(import $modu.$meth)) IMPLEMENT_ONE_ARG_FUNC(:($modu.$meth), libnm) end @@ -69,22 +75,26 @@ if get_symbol(:basic_atan2) != C_NULL end # export not import -for (meth, libnm) in [ - (:lambertw,:lambertw), # in add-on packages, not base +for (meth, libnm) in [ # in add-on packages, not base + (:lambertw,:lambertw) ] IMPLEMENT_ONE_ARG_FUNC(meth, libnm) eval(Expr(:export, meth)) end -## add these in until they are wrapped -Base.cbrt(a::SymbolicType) = a^(1//3) - +# d functions for (meth, fn) in [(:sind, :sin), (:cosd, :cos), (:tand, :tan), (:secd, :sec), (:cscd, :csc), (:cotd, :cot)] eval(:(import Base.$meth)) @eval begin $(meth)(a::SymbolicType) = $(fn)(a*PI/180) end end +for (meth, fn) in [(:asind, :asin), (:acosd, :acos), (:atand, :atan), (:asecd, :asec), (:acscd, :acsc), (:acotd, :acot)] + eval(:(import Base.$meth)) + @eval begin + $(meth)(a::SymbolicType) = $(fn)(a) * 180/PI + end +end ## Number theory module from cppwrapper @@ -97,7 +107,8 @@ for (meth, libnm) in [(:gcd, :gcd), IMPLEMENT_TWO_ARG_FUNC(:(Base.$meth), libnm, lib=:ntheory_) end -Base.binomial(n::Basic, k::Number) = binomial(N(n), N(k)) #ntheory_binomial seems wrong +#import Base: binomial; IMPLEMENT_TWO_ARG_FUNC(:binomial, :binomial, lib=:ntheory_ ) #ntheory_binomial seems wrong +Base.binomial(n::Basic, k::Number) = binomial(N(n), N(k)) Base.binomial(n::Basic, k::Integer) = binomial(N(n), N(k)) #Fix dispatch ambiguity / MethodError Base.rem(a::SymbolicType, b::SymbolicType) = a - (a ÷ b) * b Base.factorial(n::SymbolicType, k) = factorial(N(n), N(k)) @@ -109,6 +120,7 @@ for (meth, libnm) in [(:nextprime,:nextprime) eval(Expr(:export, meth)) end + function Base.convert(::Type{CVecBasic}, x::Vector{T}) where T vec = CVecBasic() for i in x diff --git a/src/numerics.jl b/src/numerics.jl index 115a785..3f4e363 100644 --- a/src/numerics.jl +++ b/src/numerics.jl @@ -200,7 +200,7 @@ Base.Real(x::Basic) = convert(Real, x) ## For generic programming in Julia float(x::Basic) = float(N(x)) -# trunc, flooor, ceil, round, rem, mod, cld, fld, +# trunc, round, rem, mod, cld, fld, isfinite(x::Basic) = x-x == 0 isnan(x::Basic) = ( x == NAN ) isinf(x::Basic) = !isnan(x) & !isfinite(x) @@ -211,11 +211,6 @@ isless(x::Basic, y::Basic) = isless(N(x), N(y)) trunc(x::Basic, args...) = Basic(trunc(N(x), args...)) trunc(::Type{T},x::Basic, args...) where {T <: Integer} = convert(T, trunc(x,args...)) -ceil(x::Basic) = Basic(ceil(N(x))) -ceil(::Type{T},x::Basic) where {T <: Integer} = convert(T, ceil(x)) - -floor(x::Basic) = Basic(floor(N(x))) -floor(::Type{T},x::Basic) where {T <: Integer} = convert(T, floor(x)) round(x::Basic; kwargs...) = Basic(round(N(x); kwargs...)) round(::Type{T},x::Basic; kwargs...) where {T <: Integer} = convert(T, round(x; kwargs...)) diff --git a/src/types.jl b/src/types.jl index d778a69..2b4dced 100644 --- a/src/types.jl +++ b/src/types.jl @@ -145,30 +145,44 @@ end ## Follow, somewhat, the python names: symbols to construct symbols, @vars """ -Macro to define 1 or more variables in the main workspace. + @vars x y[1:5] z() -Symbolic values are defined with `_symbol`. This is a convenience +Macro to define 1 or more variables or symbolic function Example ``` @vars x y z +@vars x[1:4] +@vars u(), x ``` + """ -macro vars(x...) - q=Expr(:block) - if length(x) == 1 && isa(x[1],Expr) - @assert x[1].head === :tuple "@syms expected a list of symbols" - x = x[1].args +macro vars(xs...) + # If the user separates declaration with commas, the top-level expression is a tuple + if length(xs) == 1 && isa(xs[1], Expr) && xs[1].head == :tuple + _gensyms(xs[1].args...) + elseif length(xs) > 0 + _gensyms(xs...) end - for s in x - @assert isa(s,Symbol) "@syms expected a list of symbols" - push!(q.args, Expr(:(=), esc(s), Expr(:call, :(SymEngine._symbol), Expr(:quote, s)))) +end + +function _gensyms(xs...) + asstokw(a) = Expr(:kw, esc(a), true) + + # Each declaration is parsed and generates a declaration using `symbols` + symdefs = map(xs) do expr + decl = parsedecl(expr) + symname = sym(decl) + symname, gendecl(decl) end - push!(q.args, Expr(:tuple, map(esc, x)...)) - q + syms, defs = collect(zip(symdefs...)) + + # The macro returns a tuple of Symbols that were declared + Expr(:block, defs..., :(tuple($(map(esc,syms)...)))) end + ## We also have a wrapper type that can be used to control dispatch ## pros: wrapping adds overhead, so if possible best to use Basic ## cons: have to write methods meth(x::Basic, ...) = meth(BasicType(x),...) @@ -305,4 +319,3 @@ function Serialization.deserialize(s::Serialization.AbstractSerializer, ::Type{B throw_if_error(res) return a end - diff --git a/test/runtests.jl b/test/runtests.jl index bc7e5d3..b6132a8 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -1,4 +1,5 @@ using SymEngine +using SpecialFunctions using Compat using Test using Serialization @@ -20,6 +21,13 @@ end @test_throws UndefVarError isdefined(w) @test_throws Exception show(Basic()) +# test @vars constructions +@vars a, b[0:4], c(), d=>"D" +@test length(b) == 5 +@test isa(c, SymFunction) +@test repr(d) == "D" + + a = x^2 + x/2 - x*y*5 b = diff(a, x) @test b == 2*x + 1//2 - 5*y @@ -62,10 +70,9 @@ c = Basic(-5) @test abs(c) == 5 @test abs(c) != 4 -show(a) -println() -show(b) -println() +# test show +repr("text/plain", a) == (1/2)*x - 5*x*y + x^2 +repr("text/plain", b) == 1/2 + 2*x - 5*y @test 1 // x == 1 / x @test x // 2 == (1//2) * x @@ -79,6 +86,65 @@ println() @test subs(sin(x), x, pi) == 0 @test sind(Basic(30)) == 1 // 2 +# symbolic functions +@testset "mathfuns" begin + @vars a, b + @testset for fn ∈ (sin, cos, tan, + csc, sec, cot, + asin, acos, atan, + acsc, asec, acot, + sinh, cosh, tan, + csch, sech, coth, + asinh, acosh, atanh, + acsch, asech, acoth, + sind, cosd, tand, + cscd, secd, cotd, + asind, acosd, atand, + acscd, asecd, acotd, + abs, log, exp, + sqrt, cbrt, + floor, ceil, + erf, erfc, gamma + ) + @test isa(fn(a), Basic) + u = fn(Basic(1/2)) + N(u) # can evaluate + end + + # evalf fails on these + for fn ∈ ( lambertw, eta, loggamma, zeta) + @test isa(fn(a), Basic) + u = fn(Basic(1/2)) + @test_broken N(u) + end + + # two arg work on numeric values, not symbols + u, v = Basic(10), Basic(3) + + @test_broken gcd(a*b, a^2) == a + @test gcd(u, v) == 1 + + @test_broken lcm(a, a*b) == a*b + @test lcm(u, v) == 30 + + @test_throws DivideError div(a,b) + @test div(u, v) == 3 + + @test_throws DivideError mod(a,b) + @test mod(u, v) == 1 + + @test_throws DivideError rem(a,b) + @test rem(u, v) == 1 + + @test_throws DivideError divrem(a,b) + @test divrem(u, v) == (3, 1) + + @test binomial(u, v) == binomial(10,3) + + @test nextprime(u) == 11 +end + + ## calculus x,y = symbols("x y") n = Basic(2)