Skip to content

Commit

Permalink
Merge pull request #549 from JuliaSymbolics/s/newexpr
Browse files Browse the repository at this point in the history
Updates for Unityper types
  • Loading branch information
shashi authored Jan 14, 2023
2 parents a3fa921 + 1189110 commit 180b178
Show file tree
Hide file tree
Showing 28 changed files with 282 additions and 246 deletions.
6 changes: 1 addition & 5 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@ Libdl = "8f399da3-3557-5675-b5ff-fb832c97cbdb"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09"
Markdown = "d6f4376e-aef5-505a-96c1-9c027394607a"
Metatheory = "e9d8d322-4543-424a-9be4-0cc815abe26c"
NaNMath = "77ba4419-2d1f-58cd-9bb1-8ffee604a2e3"
RecipesBase = "3cdcf5f2-1ef4-517c-9805-6587b60abb01"
Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
Expand All @@ -32,7 +31,6 @@ SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b"
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
SymbolicUtils = "d1185830-fcd6-423d-90d6-eec64667417b"
TermInterface = "8ea1fca8-c5ef-4a55-8b96-4e9afe9c9a3c"
TreeViews = "a2a6695c-b41b-5b7d-aed9-dbfdeacea5d7"

[compat]
Expand All @@ -48,7 +46,6 @@ IfElse = "0.1"
LaTeXStrings = "1.3"
Latexify = "0.11, 0.12, 0.13, 0.14, 0.15"
MacroTools = "0.5"
Metatheory = "1.2.0"
NaNMath = "0.3, 1"
RecipesBase = "1.1"
Reexport = "0.2, 1"
Expand All @@ -59,8 +56,7 @@ SciMLBase = "1.8"
Setfield = "0.7, 0.8, 1"
SpecialFunctions = "0.7, 0.8, 0.9, 0.10, 1.0, 2"
StaticArrays = "1.1"
SymbolicUtils = "0.18, 0.19"
TermInterface = "0.2, 0.3"
SymbolicUtils = "1.0.1"
TreeViews = "0.3"
LambertW = "0.4.5"
julia = "1.6"
Expand Down
12 changes: 4 additions & 8 deletions src/Symbolics.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,6 @@ $(DocStringExtensions.README)
"""
module Symbolics

using TermInterface

using Metatheory

using DocStringExtensions, Markdown

using LinearAlgebra
Expand All @@ -20,15 +16,15 @@ using Setfield
import DomainSets: Domain
@reexport using SymbolicUtils

import TermInterface: similarterm, istree, operation, arguments, symtype
import SymbolicUtils: similarterm, istree, operation, arguments, symtype

import SymbolicUtils: Term, Add, Mul, Pow, Sym, Div,
import SymbolicUtils: Term, Add, Mul, Pow, Sym, Div, BasicSymbolic,
FnType, @rule, Rewriters, substitute,
promote_symtype
promote_symtype, isadd, ismul, ispow, isterm, issym, isdiv

using SymbolicUtils.Code

import Metatheory.Rewriters: Chain, Prewalk, Postwalk, Fixpoint
import SymbolicUtils.Rewriters: Chain, Prewalk, Postwalk, Fixpoint

import SymbolicUtils.Code: toexpr

Expand Down
16 changes: 9 additions & 7 deletions src/array-lib.jl
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ function Base.getindex(x::SymArray, idx...)
else
input_idx = []
output_idx = []
ranges = Dict{Sym, AbstractRange}()
ranges = Dict{BasicSymbolic, AbstractRange}()
subscripts = makesubscripts(length(idx))
for (j, i) in enumerate(idx)
if symtype(i) <: Integer
Expand Down Expand Up @@ -98,7 +98,7 @@ end

import Base: +, -
tup(c::CartesianIndex) = Tuple(c)
tup(c::Term{CartesianIndex}) = arguments(c)
tup(c::Symbolic{CartesianIndex}) = istree(c) ? arguments(c) : error("Cartesian index not found")
@wrapped function -(x::CartesianIndex, y::CartesianIndex)
CartesianIndex((tup(x) .- tup(y))...)
end
Expand Down Expand Up @@ -224,15 +224,17 @@ isdot(A, b) = isadjointvec(A) && ndims(b) == 1
isadjointvec(A::Adjoint) = ndims(parent(A)) == 1
isadjointvec(A::Transpose) = ndims(parent(A)) == 1

function isadjointvec(A::Term)
(operation(A) === (adjoint) ||
operation(A) == (transpose)) && ndims(arguments(A)[1]) == 1
function isadjointvec(A)
if istree(A)
(operation(A) === (adjoint) ||
operation(A) == (transpose)) && ndims(arguments(A)[1]) == 1
else
false
end
end

isadjointvec(A::ArrayOp) = isadjointvec(A.term)

isadjointvec(A) = false

# TODO: add more such methods
function getindex(A::AbstractArray, i::Symbolic{<:Integer}, ii::Symbolic{<:Integer}...)
Term{eltype(A)}(getindex, [A, i, ii...])
Expand Down
26 changes: 15 additions & 11 deletions src/arrays.jl
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ struct ArrayOp{T<:AbstractArray} <: Symbolic{T}
reduce
term
shape
ranges::Dict{Sym, AbstractRange} # index range each index symbol can take,
ranges::Dict{BasicSymbolic, AbstractRange} # index range each index symbol can take,
# optional for each symbol
metadata
end
Expand Down Expand Up @@ -197,7 +197,7 @@ function make_shape(output_idx, expr, ranges=Dict())
end

sz = map(output_idx) do i
if i isa Sym
if issym(i)
if haskey(ranges, i)
return axes(ranges[i], 1)
end
Expand All @@ -222,7 +222,7 @@ end


function ranges(a::ArrayOp)
rs = Dict{Sym, Any}()
rs = Dict{BasicSymbolic, Any}()
ax = idx_to_axes(a.expr)
for i in keys(ax)
if haskey(a.ranges, i)
Expand Down Expand Up @@ -316,7 +316,7 @@ get_extents(x::AbstractRange) = x
# dim: The dimension of the array indexed
# boundary: how much padding is this indexing requiring, for example
# boundary is 2 for x[i + 2], and boundary = -2 for x[i - 2]
function idx_to_axes(expr, dict=Dict{Sym, Vector}(), ranges=Dict())
function idx_to_axes(expr, dict=Dict{Any, Vector}(), ranges=Dict())
if istree(expr)
if operation(expr) === (getindex)
args = arguments(expr)
Expand Down Expand Up @@ -376,7 +376,7 @@ function arrterm(f, args...)
atype{etype, nd}
end

setmetadata(Term{S}(f, args),
setmetadata(Term{S}(f, Any[args...]),
ArrayShapeCtx,
propagate_shape(f, args...))
end
Expand Down Expand Up @@ -461,7 +461,7 @@ const ArrayLike{T,N} = Union{
ArrayOp{AbstractArray{T,N}},
Symbolic{AbstractArray{T,N}},
Arr{T,N},
SymbolicUtils.Term{Arr{T, N}}
SymbolicUtils.Term{AbstractArray{T, N}}
} # Like SymArray but includes Arr and Term{Arr}

unwrap(x::Arr) = x.value
Expand Down Expand Up @@ -599,8 +599,12 @@ function scalarize(arr::AbstractArray, idx)
arr[idx...]
end

function scalarize(arr::Term, idx)
scalarize_op(operation(arr), arr, idx)
function scalarize(arr, idx)
if istree(arr)
scalarize_op(operation(arr), arr, idx)
else
error("scalarize is not defined for $arr at idx=$idx")
end
end

scalarize_op(f, arr) = arr
Expand Down Expand Up @@ -748,9 +752,9 @@ function arraymaker(T, shape, views, seq...)
ArrayMaker{T}(shape, [(views .=> seq)...], nothing)
end

TermInterface.istree(x::ArrayMaker) = true
TermInterface.operation(x::ArrayMaker) = arraymaker
TermInterface.arguments(x::ArrayMaker) = [eltype(x), shape(x), map(first, x.sequence), map(last, x.sequence)...]
istree(x::ArrayMaker) = true
operation(x::ArrayMaker) = arraymaker
arguments(x::ArrayMaker) = [eltype(x), shape(x), map(first, x.sequence), map(last, x.sequence)...]

shape(am::ArrayMaker) = am.shape

Expand Down
12 changes: 6 additions & 6 deletions src/build_function.jl
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,7 @@ function _build_function(target::JuliaTarget, op, args...;
cse = false, kwargs...)
dargs = map((x) -> destructure_arg(x[2], !checkbounds, Symbol("ˍ₋arg$(x[1])")), enumerate([args...]))
expr = if cse
fun = Func(dargs, [], Code.cse(op))
fun = Func(dargs, [], Code.cse(unwrap(op)))
(wrap_code !== nothing) && (fun = wrap_code(fun))
toexpr(fun, states)
else
Expand Down Expand Up @@ -142,7 +142,7 @@ function _build_function(target::JuliaTarget, op::Union{Arr, ArrayOp}, args...;
Symbol("ˍ₋arg$(x[1])")), enumerate([args...]))

expr = if cse
toexpr(Func(dargs, [], Code.cse(op)), states)
toexpr(Func(dargs, [], Code.cse(unwrap(op))), states)
else
toexpr(Func(dargs, [], op), states)
end
Expand Down Expand Up @@ -429,7 +429,7 @@ function _make_array(rhss::AbstractArray, similarto, cse)
if _issparse(arr)
_make_sparse_array(arr, similarto, cse)
elseif cse
Code.cse(MakeArray(arr, similarto))
Code.cse(MakeArray(unwrap.(arr), similarto))
else
MakeArray(arr, similarto)
end
Expand Down Expand Up @@ -558,7 +558,7 @@ function numbered_expr(O::Symbolic,varnumbercache,args...;varordering = args[1],
states = LazyState(),
lhsname=:du,rhsnames=[Symbol("MTK$i") for i in 1:length(args)])
O = value(O)
if (O isa Sym || isa(operation(O), Sym)) || (istree(O) && operation(O) == getindex)
if (issym(O) || issym(operation(O))) || (istree(O) && operation(O) == getindex)
(j,i) = get(varnumbercache, O, (nothing, nothing))
if !isnothing(j)
return i==0 ? :($(rhsnames[j])) : :($(rhsnames[j])[$(i+offset)])
Expand All @@ -572,7 +572,7 @@ function numbered_expr(O::Symbolic,varnumbercache,args...;varordering = args[1],
Expr(:call, Symbol(operation(O)), (numbered_expr(x,varnumbercache,args...;offset=offset,lhsname=lhsname,
rhsnames=rhsnames,varordering=varordering) for x in arguments(O))...)
end
elseif O isa Sym
elseif issym(O)
tosymbol(O, escape=false)
else
O
Expand All @@ -584,7 +584,7 @@ function numbered_expr(de::Equation,varnumbercache,args...;varordering = args[1]

varordering = value.(args[1])
var = var_from_nested_derivative(de.lhs)[1]
i = findfirst(x->isequal(tosymbol(x isa Sym ? x : operation(x), escape=false), tosymbol(var, escape=false)),varordering)
i = findfirst(x->isequal(tosymbol(issym(x) ? x : operation(x), escape=false), tosymbol(var, escape=false)),varordering)
:($lhsname[$(i+offset)] = $(numbered_expr(de.rhs,varnumbercache,args...;offset=offset,
varordering = varordering,
lhsname = lhsname,
Expand Down
12 changes: 6 additions & 6 deletions src/complex.jl
Original file line number Diff line number Diff line change
Expand Up @@ -17,16 +17,16 @@ function wrapper_type(::Type{Complex{T}}) where T
Symbolics.has_symwrapper(T) ? Complex{wrapper_type(T)} : Complex{T}
end

TermInterface.symtype(a::ComplexTerm{T}) where T = Complex{T}
TermInterface.istree(a::ComplexTerm) = true
TermInterface.operation(a::ComplexTerm{T}) where T = Complex{T}
TermInterface.arguments(a::ComplexTerm) = [a.re, a.im]
symtype(a::ComplexTerm{T}) where T = Complex{T}
istree(a::ComplexTerm) = true
operation(a::ComplexTerm{T}) where T = Complex{T}
arguments(a::ComplexTerm) = [a.re, a.im]

function TermInterface.similarterm(t::ComplexTerm, f, args, symtype; metadata=nothing, exprhead=exprhead(t))
function similarterm(t::ComplexTerm, f, args, symtype; metadata=nothing)
if f <: Complex
ComplexTerm{real(f)}(args...)
else
similarterm(first(args), f, args, symtype; metadata=metadata, exprhead=exprhead)
similarterm(first(args), f, args, symtype; metadata=metadata)
end
end

Expand Down
33 changes: 17 additions & 16 deletions src/diff.jl
Original file line number Diff line number Diff line change
Expand Up @@ -37,8 +37,7 @@ end
(D::Differential)(x::Num) = Num(D(value(x)))
SymbolicUtils.promote_symtype(::Differential, x) = x

is_derivative(x::Term) = operation(x) isa Differential
is_derivative(x) = false
is_derivative(x) = istree(x) ? operation(x) isa Differential : false

Base.:*(D1, D2::Differential) = D1 D2
Base.:*(D1::Differential, D2) = D1 D2
Expand All @@ -51,7 +50,7 @@ Base.:(==)(D1::Differential, D2::Differential) = isequal(D1.x, D2.x)
Base.hash(D::Differential, u::UInt) = hash(D.x, xor(u, 0xdddddddddddddddd))

_isfalse(occ::Bool) = occ === false
_isfalse(occ::Term) = _isfalse(operation(occ))
_isfalse(occ::Symbolic) = istree(occ) && _isfalse(operation(occ))

function occursin_info(x, expr, fail = true)
if symtype(expr) <: AbstractArray
Expand Down Expand Up @@ -85,7 +84,7 @@ function occursin_info(x, expr, fail = true)
return false
end

!istree(expr) && return false
!istree(expr) && return isequal(x, expr)
if isequal(x, expr)
true
else
Expand Down Expand Up @@ -128,11 +127,11 @@ function recursive_hasoperator(op, O)
if operation(O) isa op
return true
else
if O isa Union{Add, Mul}
if isadd(O) || ismul(O)
any(recursive_hasoperator(op), keys(O.dict))
elseif O isa Pow
elseif ispow(O)
recursive_hasoperator(op)(O.base) || recursive_hasoperator(op)(O.exp)
elseif O isa SymbolicUtils.Div
elseif isdiv(O)
recursive_hasoperator(op)(O.num) || recursive_hasoperator(op)(O.den)
else
any(recursive_hasoperator(op), arguments(O))
Expand Down Expand Up @@ -176,7 +175,7 @@ function expand_derivatives(O::Symbolic, simplify=false; occurances=nothing)

if !istree(arg)
return D(arg) # Cannot expand
elseif (op = operation(arg); isa(op, Sym))
elseif (op = operation(arg); issym(op))
inner_args = arguments(arg)
if any(isequal(D.x), inner_args)
return D(arg) # base case if any argument is directly equal to the i.v.
Expand Down Expand Up @@ -437,16 +436,18 @@ $(SIGNATURES)
A helper function for computing the Jacobian of an array of expressions with respect to
an array of variable expressions.
"""
function jacobian(ops::AbstractVector, vars::AbstractVector; simplify=false)
ops = Symbolics.scalarize(ops)
vars = Symbolics.scalarize(vars)
function jacobian(ops::AbstractVector, vars::AbstractVector; simplify=false, scalarize=true)
if scalarize
ops = Symbolics.scalarize(ops)
vars = Symbolics.scalarize(vars)
end
Num[Num(expand_derivatives(Differential(value(v))(value(O)),simplify)) for O in ops, v in vars]
end

function jacobian(ops::ArrayLike{T, 1}, vars::ArrayLike{T, 1}; simplify=false) where T
ops = scalarize(ops)
vars = scalarize(vars) # Suboptimal, but prevents wrong results on Arr for now. Arr resulting from a symbolic function will fail on this due to unknown size.
Num[Num(expand_derivatives(Differential(value(v))(value(O)),simplify)) for O in ops, v in vars]
function jacobian(ops, vars; simplify=false)
ops = vec(scalarize(ops))
vars = vec(scalarize(vars)) # Suboptimal, but prevents wrong results on Arr for now. Arr resulting from a symbolic function will fail on this due to unknown size.
jacobian(ops, vars; simplify=simplify, scalarize=false)
end

"""
Expand Down Expand Up @@ -642,7 +643,7 @@ let
error("Function of unknown linearity used: ", ~f)
end
end
@rule ~x::(x->x isa Sym) => 0]
@rule ~x::issym => 0]
linearity_propagator = Fixpoint(Postwalk(Chain(linearity_rules); similarterm=basic_simterm))

global hessian_sparsity
Expand Down
10 changes: 6 additions & 4 deletions src/domains.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,16 @@ struct VarDomainPairing
domain::Domain
end

Base.:(variable::Union{Sym,Term,Num},domain::Domain) = VarDomainPairing(value(variable),domain)
Base.:(variable::Union{Sym,Term,Num},domain::Interval) = VarDomainPairing(value(variable),domain)
const DomainedVar = Union{Symbolic{<:Number}, Num}

Base.:(variable::DomainedVar,domain::Domain) = VarDomainPairing(value(variable),domain)
Base.:(variable::DomainedVar,domain::Interval) = VarDomainPairing(value(variable),domain)

# Construct Interval domain from a Tuple
Base.:(variable::Union{Sym,Term,Num},domain::NTuple{2,Real}) = VarDomainPairing(variable,Interval(domain...))
Base.:(variable::DomainedVar,domain::NTuple{2,Real}) = VarDomainPairing(variable,Interval(domain...))

# Multiple variables
Base.:(variables::NTuple{N,Union{Sym,Term,Num}},domain::Domain) where N = VarDomainPairing(value.(variables),domain)
Base.:(variables::NTuple{N,DomainedVar},domain::Domain) where N = VarDomainPairing(value.(variables),domain)

function infimum(d::AbstractInterval{T}) where T <: Num
leftendpoint(d)
Expand Down
2 changes: 1 addition & 1 deletion src/groebner_basis.jl
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ function symbol_to_poly(sympolys::AbstractArray)
sort!(stdsympolys, lt=(<ₑ))

pvar2sym = Bijections.Bijection{Any,Any}()
sym2term = Dict{Sym,Any}()
sym2term = Dict{BasicSymbolic,Any}()
polyforms = map(f -> PolyForm(f, pvar2sym, sym2term), stdsympolys)

# Discover common coefficient type
Expand Down
Loading

0 comments on commit 180b178

Please sign in to comment.