Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

At vars #266

Open
wants to merge 9 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 8 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/SymEngine.jl
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ const libversion = get_libversion()
include("exceptions.jl")
include("types.jl")
include("ctypes.jl")
include("decl.jl")
include("display.jl")
include("mathops.jl")
include("mathfuns.jl")
Expand Down
150 changes: 150 additions & 0 deletions src/decl.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,150 @@
# !!! Note:
# Many thanks to `@matthieubulte` for this contribution to `SymPy`.

# The map_subscripts function is stolen from Symbolics.jl
Comment on lines +1 to +4
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What are the licenses for these codes?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

const IndexMap = Dict{Char,Char}(
'-' => '₋',
'0' => '₀',
'1' => '₁',
'2' => '₂',
'3' => '₃',
'4' => '₄',
'5' => '₅',
'6' => '₆',
'7' => '₇',
'8' => '₈',
'9' => '₉')

function map_subscripts(indices)
str = string(indices)
join(IndexMap[c] for c in str)
end

# Define a type hierarchy to describe a variable declaration. This is mainly for convenient pattern matching later.
abstract type VarDecl end

struct SymDecl <: VarDecl
sym :: Symbol
end

struct NamedDecl <: VarDecl
name :: String
rest :: VarDecl
end

struct FunctionDecl <: VarDecl
rest :: VarDecl
end

struct TensorDecl <: VarDecl
ranges :: Vector{AbstractRange}
rest :: VarDecl
end

struct AssumptionsDecl <: VarDecl
assumptions :: Vector{Symbol}
rest :: VarDecl
end

# Transform a Decl struct in an Expression that calls SymPy to declare the corresponding symbol
function gendecl(x::VarDecl)
asstokw(a) = Expr(:kw, esc(a), true)
val = :($(ctor(x))($(name(x, missing)), $(map(asstokw, assumptions(x))...)))
:($(esc(sym(x))) = $(genreshape(val, x)))
end

# Transform an expression in a Decl struct
function parsedecl(expr)
# @vars x
if isa(expr, Symbol)
return SymDecl(expr)

# @vars x::assumptions, where assumption = assumptionkw | (assumptionkw...)
#= no assumptions in SymEngine
elseif isa(expr, Expr) && expr.head == :(::)
symexpr, assumptions = expr.args
assumptions = isa(assumptions, Symbol) ? [assumptions] : assumptions.args
return AssumptionsDecl(assumptions, parsedecl(symexpr))
=#

# @vars x=>"name"
elseif isa(expr, Expr) && expr.head == :call && expr.args[1] == :(=>)
length(expr.args) == 3 || parseerror()
isa(expr.args[3], String) || parseerror()

expr, strname = expr.args[2:end]
return NamedDecl(strname, parsedecl(expr))

# @vars x()
elseif isa(expr, Expr) && expr.head == :call && expr.args[1] != :(=>)
length(expr.args) == 1 || parseerror()
return FunctionDecl(parsedecl(expr.args[1]))

# @vars x[1:5, 3:9]
elseif isa(expr, Expr) && expr.head == :ref
length(expr.args) > 1 || parseerror()
ranges = map(parserange, expr.args[2:end])
return TensorDecl(ranges, parsedecl(expr.args[1]))
else
parseerror()
end
end

function parserange(expr)
range = eval(expr)
isa(range, AbstractRange) || parseerror()
range
end

sym(x::SymDecl) = x.sym
sym(x::NamedDecl) = sym(x.rest)
sym(x::FunctionDecl) = sym(x.rest)
sym(x::TensorDecl) = sym(x.rest)
sym(x::AssumptionsDecl) = sym(x.rest)

ctor(::SymDecl) = :symbols
ctor(x::NamedDecl) = ctor(x.rest)
ctor(::FunctionDecl) = :SymFunction
ctor(x::TensorDecl) = ctor(x.rest)
ctor(x::AssumptionsDecl) = ctor(x.rest)

assumptions(::SymDecl) = []
assumptions(x::NamedDecl) = assumptions(x.rest)
assumptions(x::FunctionDecl) = assumptions(x.rest)
assumptions(x::TensorDecl) = assumptions(x.rest)
assumptions(x::AssumptionsDecl) = x.assumptions

# Reshape is not used by most nodes, but TensorNodes require the output to be given
# the shape matching the specification. For instance if @vars x[1:3, 2:6], we should
# have size(x) = (3, 5)
genreshape(expr, ::SymDecl) = expr
genreshape(expr, x::NamedDecl) = genreshape(expr, x.rest)
genreshape(expr, x::FunctionDecl) = genreshape(expr, x.rest)
genreshape(expr, x::TensorDecl) = let
shape = tuple(length.(x.ranges)...)
:(reshape(collect($(expr)), $(shape)))
end
genreshape(expr, x::AssumptionsDecl) = genreshape(expr, x.rest)

# To find out the name, we need to traverse in both directions to make sure that each node can get
# information from parents and children about possible name.
# This is done because the expr tree will always look like NamedDecl -> ... -> TensorDecl -> ... -> SymDecl
# and the TensorDecl node will need to know if it should create names base on a NamedDecl parent or
# based on the SymDecl leaf.
name(x::SymDecl, parentname) = coalesce(parentname, String(x.sym))
name(x::NamedDecl, parentname) = coalesce(name(x.rest, x.name), x.name)
name(x::FunctionDecl, parentname) = name(x.rest, parentname)
name(x::AssumptionsDecl, parentname) = name(x.rest, parentname)
name(x::TensorDecl, parentname) = let
basename = name(x.rest, parentname)
# we need to double reverse the indices to make sure that we traverse them in the natural order
namestensor = map(Iterators.product(x.ranges...)) do ind
sub = join(map(map_subscripts, ind), "_")
string(basename, sub)
end
join(namestensor[:], ", ")
end

function parseerror()
error("Incorrect @vars syntax. Try `@vars x=>\"x₀\" y() z[0:4]` for instance.")
end
88 changes: 50 additions & 38 deletions src/mathfuns.jl
Original file line number Diff line number Diff line change
Expand Up @@ -26,38 +26,44 @@ end
## import from base one argument functions
## these are from cwrapper.cpp, one arg func
for (meth, libnm, modu) in [
(:abs,:abs,:Base),
(:sin,:sin,:Base),
(:cos,:cos,:Base),
(:tan,:tan,:Base),
(:csc,:csc,:Base),
(:sec,:sec,:Base),
(:cot,:cot,:Base),
(:asin,:asin,:Base),
(:acos,:acos,:Base),
(:asec,:asec,:Base),
(:acsc,:acsc,:Base),
(:atan,:atan,:Base),
(:acot,:acot,:Base),
(:sinh,:sinh,:Base),
(:cosh,:cosh,:Base),
(:tanh,:tanh,:Base),
(:csch,:csch,:Base),
(:sech,:sech,:Base),
(:coth,:coth,:Base),
(:asinh,:asinh,:Base),
(:acosh,:acosh,:Base),
(:asech,:asech,:Base),
(:acsch,:acsch,:Base),
(:atanh,:atanh,:Base),
(:acoth,:acoth,:Base),
(:gamma,:gamma,:SpecialFunctions),
(:log,:log,:Base),
(:sqrt,:sqrt,:Base),
(:exp,:exp,:Base),
(:eta,:dirichlet_eta,:SpecialFunctions),
(:zeta,:zeta,:SpecialFunctions),
]
(:abs,:abs,:Base),
(:sin,:sin,:Base),
(:cos,:cos,:Base),
(:tan,:tan,:Base),
(:csc,:csc,:Base),
(:sec,:sec,:Base),
(:cot,:cot,:Base),
(:asin,:asin,:Base),
(:acos,:acos,:Base),
(:asec,:asec,:Base),
(:acsc,:acsc,:Base),
(:atan,:atan,:Base),
(:acot,:acot,:Base),
(:sinh,:sinh,:Base),
(:cosh,:cosh,:Base),
(:tanh,:tanh,:Base),
(:csch,:csch,:Base),
(:sech,:sech,:Base),
(:coth,:coth,:Base),
(:asinh,:asinh,:Base),
(:acosh,:acosh,:Base),
(:asech,:asech,:Base),
(:acsch,:acsch,:Base),
(:atanh,:atanh,:Base),
(:acoth,:acoth,:Base),
(:log,:log,:Base),
(:sqrt,:sqrt,:Base),
(:cbrt,:cbrt,:Base),
(:exp,:exp,:Base),
(:floor,:floor,:Base),
(:ceil, :ceiling,:Base),
(:erf, :erf, :SpecialFunctions),
(:erfc, :erfc, :SpecialFunctions),
(:eta,:dirichlet_eta,:SpecialFunctions),
(:gamma,:gamma,:SpecialFunctions),
(:loggamma,:loggamma,:SpecialFunctions),
(:zeta,:zeta,:SpecialFunctions),
]
eval(:(import $modu.$meth))
IMPLEMENT_ONE_ARG_FUNC(:($modu.$meth), libnm)
end
Expand All @@ -69,22 +75,26 @@ if get_symbol(:basic_atan2) != C_NULL
end

# export not import
for (meth, libnm) in [
(:lambertw,:lambertw), # in add-on packages, not base
for (meth, libnm) in [ # in add-on packages, not base
(:lambertw,:lambertw)
]
IMPLEMENT_ONE_ARG_FUNC(meth, libnm)
eval(Expr(:export, meth))
end

## add these in until they are wrapped
Base.cbrt(a::SymbolicType) = a^(1//3)

# d functions
for (meth, fn) in [(:sind, :sin), (:cosd, :cos), (:tand, :tan), (:secd, :sec), (:cscd, :csc), (:cotd, :cot)]
eval(:(import Base.$meth))
@eval begin
$(meth)(a::SymbolicType) = $(fn)(a*PI/180)
end
end
for (meth, fn) in [(:asind, :asin), (:acosd, :acos), (:atand, :atan), (:asecd, :asec), (:acscd, :acsc), (:acotd, :acot)]
eval(:(import Base.$meth))
@eval begin
$(meth)(a::SymbolicType) = $(fn)(a) * 180/PI
end
end


## Number theory module from cppwrapper
Expand All @@ -97,7 +107,8 @@ for (meth, libnm) in [(:gcd, :gcd),
IMPLEMENT_TWO_ARG_FUNC(:(Base.$meth), libnm, lib=:ntheory_)
end

Base.binomial(n::Basic, k::Number) = binomial(N(n), N(k)) #ntheory_binomial seems wrong
#import Base: binomial; IMPLEMENT_TWO_ARG_FUNC(:binomial, :binomial, lib=:ntheory_ ) #ntheory_binomial seems wrong
Base.binomial(n::Basic, k::Number) = binomial(N(n), N(k))
Base.binomial(n::Basic, k::Integer) = binomial(N(n), N(k)) #Fix dispatch ambiguity / MethodError
Base.rem(a::SymbolicType, b::SymbolicType) = a - (a ÷ b) * b
Base.factorial(n::SymbolicType, k) = factorial(N(n), N(k))
Expand All @@ -109,6 +120,7 @@ for (meth, libnm) in [(:nextprime,:nextprime)
eval(Expr(:export, meth))
end


function Base.convert(::Type{CVecBasic}, x::Vector{T}) where T
vec = CVecBasic()
for i in x
Expand Down
7 changes: 1 addition & 6 deletions src/numerics.jl
Original file line number Diff line number Diff line change
Expand Up @@ -200,7 +200,7 @@ Base.Real(x::Basic) = convert(Real, x)
## For generic programming in Julia
float(x::Basic) = float(N(x))

# trunc, flooor, ceil, round, rem, mod, cld, fld,
# trunc, round, rem, mod, cld, fld,
isfinite(x::Basic) = x-x == 0
isnan(x::Basic) = ( x == NAN )
isinf(x::Basic) = !isnan(x) & !isfinite(x)
Expand All @@ -211,11 +211,6 @@ isless(x::Basic, y::Basic) = isless(N(x), N(y))
trunc(x::Basic, args...) = Basic(trunc(N(x), args...))
trunc(::Type{T},x::Basic, args...) where {T <: Integer} = convert(T, trunc(x,args...))

ceil(x::Basic) = Basic(ceil(N(x)))
ceil(::Type{T},x::Basic) where {T <: Integer} = convert(T, ceil(x))

floor(x::Basic) = Basic(floor(N(x)))
floor(::Type{T},x::Basic) where {T <: Integer} = convert(T, floor(x))

round(x::Basic; kwargs...) = Basic(round(N(x); kwargs...))
round(::Type{T},x::Basic; kwargs...) where {T <: Integer} = convert(T, round(x; kwargs...))
Expand Down
39 changes: 26 additions & 13 deletions src/types.jl
Original file line number Diff line number Diff line change
Expand Up @@ -145,30 +145,44 @@ end
## Follow, somewhat, the python names: symbols to construct symbols, @vars

"""
Macro to define 1 or more variables in the main workspace.
@vars x y[1:5] z()

Symbolic values are defined with `_symbol`. This is a convenience
Macro to define 1 or more variables or symbolic function

Example
```
@vars x y z
@vars x[1:4]
@vars u(), x
```

"""
macro vars(x...)
q=Expr(:block)
if length(x) == 1 && isa(x[1],Expr)
@assert x[1].head === :tuple "@syms expected a list of symbols"
x = x[1].args
macro vars(xs...)
# If the user separates declaration with commas, the top-level expression is a tuple
if length(xs) == 1 && isa(xs[1], Expr) && xs[1].head == :tuple
_gensyms(xs[1].args...)
elseif length(xs) > 0
_gensyms(xs...)
end
for s in x
@assert isa(s,Symbol) "@syms expected a list of symbols"
push!(q.args, Expr(:(=), esc(s), Expr(:call, :(SymEngine._symbol), Expr(:quote, s))))
end

function _gensyms(xs...)
asstokw(a) = Expr(:kw, esc(a), true)

# Each declaration is parsed and generates a declaration using `symbols`
symdefs = map(xs) do expr
decl = parsedecl(expr)
symname = sym(decl)
symname, gendecl(decl)
end
push!(q.args, Expr(:tuple, map(esc, x)...))
q
syms, defs = collect(zip(symdefs...))

# The macro returns a tuple of Symbols that were declared
Expr(:block, defs..., :(tuple($(map(esc,syms)...))))
end



## We also have a wrapper type that can be used to control dispatch
## pros: wrapping adds overhead, so if possible best to use Basic
## cons: have to write methods meth(x::Basic, ...) = meth(BasicType(x),...)
Expand Down Expand Up @@ -305,4 +319,3 @@ function Serialization.deserialize(s::Serialization.AbstractSerializer, ::Type{B
throw_if_error(res)
return a
end

Loading
Loading