Skip to content

Commit

Permalink
Merge pull request #1005 from AayushSabharwal/as/indexing-rework
Browse files Browse the repository at this point in the history
fix: remove RecursiveArrayTools dependency, use SymbolicIndexingInterface
  • Loading branch information
ChrisRackauckas authored Dec 11, 2023
2 parents 4b3081e + 2ded1ac commit 49172c0
Show file tree
Hide file tree
Showing 9 changed files with 60 additions and 14 deletions.
4 changes: 2 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@ Markdown = "d6f4376e-aef5-505a-96c1-9c027394607a"
NaNMath = "77ba4419-2d1f-58cd-9bb1-8ffee604a2e3"
PrecompileTools = "aea7be01-6a6a-4083-8856-8a6e6704d82a"
RecipesBase = "3cdcf5f2-1ef4-517c-9805-6587b60abb01"
RecursiveArrayTools = "731186ca-8d62-57ce-b412-fbd966d074cd"
Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
Requires = "ae029012-a4dd-5104-9daa-d747884805df"
RuntimeGeneratedFunctions = "7e49a35a-f44a-4d26-94aa-eba1b4ca6b47"
Expand All @@ -35,6 +34,7 @@ Setfield = "efcf1570-3423-57d1-acb7-fd33fddbac46"
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b"
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
SymbolicIndexingInterface = "2efcf032-c050-4f8e-a9bb-153293bab1f5"
SymbolicUtils = "d1185830-fcd6-423d-90d6-eec64667417b"

[weakdeps]
Expand Down Expand Up @@ -63,7 +63,6 @@ MacroTools = "0.5"
NaNMath = "0.3, 1"
PrecompileTools = "1"
RecipesBase = "1.1"
RecursiveArrayTools = "2"
Reexport = "0.2, 1"
ReferenceTests = "0.9"
Requires = "1.1"
Expand All @@ -72,6 +71,7 @@ SciMLBase = "1.8, 2"
Setfield = "0.7, 0.8, 1"
SpecialFunctions = "0.7, 0.8, 0.9, 0.10, 1.0, 2"
StaticArrays = "1.1"
SymbolicIndexingInterface = "0.3"
SymbolicUtils = "1.4"
julia = "1.6"

Expand Down
2 changes: 2 additions & 0 deletions src/Symbolics.jl
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,8 @@ using PrecompileTools
using RuntimeGeneratedFunctions
using SciMLBase, IfElse
using MacroTools

using SymbolicIndexingInterface
end
@reexport using SymbolicUtils
RuntimeGeneratedFunctions.init(@__MODULE__)
Expand Down
6 changes: 0 additions & 6 deletions src/num.jl
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,6 @@ Num(x::Num) = x # ideally this should never be called
(n::Num)(args...) = Num(value(n)(map(value,args)...))
value(x) = unwrap(x)

SciMLBase.issymbollike(::Num) = true
SciMLBase.issymbollike(::SymbolicUtils.Symbolic) = true

SymbolicUtils.@number_methods(
Num,
Num(f(value(a))),
Expand Down Expand Up @@ -197,6 +194,3 @@ function Base.Docs.getdoc(x::Num)
end
Markdown.parse(join(strings, "\n\n "))
end

using RecursiveArrayTools
RecursiveArrayTools.issymbollike(::Union{BasicSymbolic,Num}) = true
11 changes: 10 additions & 1 deletion src/variable.jl
Original file line number Diff line number Diff line change
Expand Up @@ -399,7 +399,16 @@ end

getsource(x, val=_fail) = getmetadata(unwrap(x), VariableSource, val)

getname(x, val=_fail) = _getname(unwrap(x), val)
SymbolicIndexingInterface.symbolic_type(::Type{<:Symbolics.Num}) = ScalarSymbolic()
SymbolicIndexingInterface.symbolic_type(::Type{<:Symbolics.Arr}) = ArraySymbolic()

SymbolicIndexingInterface.hasname(x::Union{Num,Arr}) = hasname(unwrap(x))

function SymbolicIndexingInterface.hasname(x::Symbolic)
issym(x) || !istree(x) || istree(x) && (issym(operation(x)) || operation(x) == getindex)
end

SymbolicIndexingInterface.getname(x, val=_fail) = _getname(unwrap(x), val)

function getparent(x, val=_fail)
maybe_parent = getmetadata(x, Symbolics.GetindexParent, nothing)
Expand Down
4 changes: 2 additions & 2 deletions src/wrapper-types.jl
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,9 @@ function set_where(subt, supert)
Expr(:where, supert, Ts...)
end

getname(x::Symbol) = x
SymbolicIndexingInterface.getname(x::Symbol) = x

function getname(x::Expr)
function SymbolicIndexingInterface.getname(x::Expr)
@assert x.head == :curly
return x.args[1]
end
Expand Down
3 changes: 0 additions & 3 deletions test/overloads.jl
Original file line number Diff line number Diff line change
Expand Up @@ -237,6 +237,3 @@ for f in [<, <=, >, >=, isless]
end

@test_nowarn binomial(t, 1)

using RecursiveArrayTools
@test RecursiveArrayTools.issymbollike(t)
9 changes: 9 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,15 @@ if GROUP == "All" || GROUP == "Core"
@safetestset "LogExpFunctions Test" begin include("logexpfunctions.jl") end
end

if GROUP == "All" || GROUP == "Core" || GROUP == "SymbolicIndexingInterface"
@safetestset "SymbolicIndexingInterface Trait Test" begin
include("symbolic_indexing_interface_trait.jl")
end
@safetestset "SymbolicIndexingInterface Parameter Indexing Test" begin
include("symbolic_indexing_interface_parameter_indexing.jl")
end
end

if GROUP == "Downstream"
activate_downstream_env()
#@time @safetestset "ParameterizedFunctions MATLABDiffEq Regression Test" begin include("downstream/ParameterizedFunctions_MATLAB.jl") end
Expand Down
23 changes: 23 additions & 0 deletions test/symbolic_indexing_interface_parameter_indexing.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
using SymbolicIndexingInterface
using Symbolics

struct FakeIntegrator{P}
p::P
end

SymbolicIndexingInterface.symbolic_container(fp::FakeIntegrator) = fp.sys
SymbolicIndexingInterface.parameter_values(fp::FakeIntegrator) = fp.p

@variables a[1:2] b
sys = SymbolCache([:x, :y, :z], [a[1], a[2], b], [:t])
p = [1.0, 2.0, 3.0]
fi = FakeIntegrator(copy(p))
for (i, sym) in [(1, a[1]), (2, a[2]), (3, b), ([1,2], a), ([1, 3], [a[1], b]), ((2, 3), (a[2], b))]
get = getp(sys, sym)
set! = setp(sys, sym)
true_value = i isa Tuple ? getindex.((p,), i) : p[i]
@test get(fi) == true_value
set!(fi, 0.5 .* i)
@test get(fi) == 0.5 .* i
set!(fi, true_value)
end
12 changes: 12 additions & 0 deletions test/symbolic_indexing_interface_trait.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
using Symbolics
using SymbolicUtils
using SymbolicIndexingInterface

@test all(symbolic_type.([SymbolicUtils.BasicSymbolic, Symbolics.Num]) .==
(ScalarSymbolic(),))
@test symbolic_type(Symbolics.Arr) == ArraySymbolic()
@variables x
@test symbolic_type(x) == ScalarSymbolic()
@variables y[1:3]
@test symbolic_type(y) == ArraySymbolic()
@test all(symbolic_type.(collect(y)) .== (ScalarSymbolic(),))

0 comments on commit 49172c0

Please sign in to comment.