From 59567b4bdd869482b474f58bd4122728e697cbe8 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Fri, 27 Oct 2023 17:07:16 +0530 Subject: [PATCH 1/4] fix: remove RecursiveArrayTools dependency --- Project.toml | 2 -- src/num.jl | 6 ------ test/overloads.jl | 3 --- 3 files changed, 11 deletions(-) diff --git a/Project.toml b/Project.toml index dd3bfd159..8aca8c0dc 100644 --- a/Project.toml +++ b/Project.toml @@ -26,7 +26,6 @@ Markdown = "d6f4376e-aef5-505a-96c1-9c027394607a" NaNMath = "77ba4419-2d1f-58cd-9bb1-8ffee604a2e3" PrecompileTools = "aea7be01-6a6a-4083-8856-8a6e6704d82a" RecipesBase = "3cdcf5f2-1ef4-517c-9805-6587b60abb01" -RecursiveArrayTools = "731186ca-8d62-57ce-b412-fbd966d074cd" Reexport = "189a3867-3050-52da-a836-e630ba90ab69" Requires = "ae029012-a4dd-5104-9daa-d747884805df" RuntimeGeneratedFunctions = "7e49a35a-f44a-4d26-94aa-eba1b4ca6b47" @@ -64,7 +63,6 @@ MacroTools = "0.5" NaNMath = "0.3, 1" PrecompileTools = "1" RecipesBase = "1.1" -RecursiveArrayTools = "2" Reexport = "0.2, 1" ReferenceTests = "0.9" Requires = "1.1" diff --git a/src/num.jl b/src/num.jl index 6646e8e6e..7244801a7 100644 --- a/src/num.jl +++ b/src/num.jl @@ -19,9 +19,6 @@ Num(x::Num) = x # ideally this should never be called (n::Num)(args...) = Num(value(n)(map(value,args)...)) value(x) = unwrap(x) -SciMLBase.issymbollike(::Num) = true -SciMLBase.issymbollike(::SymbolicUtils.Symbolic) = true - SymbolicUtils.@number_methods( Num, Num(f(value(a))), @@ -197,6 +194,3 @@ function Base.Docs.getdoc(x::Num) end Markdown.parse(join(strings, "\n\n ")) end - -using RecursiveArrayTools -RecursiveArrayTools.issymbollike(::Union{BasicSymbolic,Num}) = true diff --git a/test/overloads.jl b/test/overloads.jl index 474218554..a32b39c54 100644 --- a/test/overloads.jl +++ b/test/overloads.jl @@ -237,6 +237,3 @@ for f in [<, <=, >, >=, isless] end @test_nowarn binomial(t, 1) - -using RecursiveArrayTools -@test RecursiveArrayTools.issymbollike(t) From 3afe412e0cf141578e09d9ed839228351c6eac13 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Tue, 14 Nov 2023 20:58:18 +0530 Subject: [PATCH 2/4] refactor: overload getname from SymbolicIndexingInterface --- Project.toml | 2 ++ src/Symbolics.jl | 2 ++ src/variable.jl | 5 ++++- src/wrapper-types.jl | 4 ++-- 4 files changed, 10 insertions(+), 3 deletions(-) diff --git a/Project.toml b/Project.toml index 8aca8c0dc..d7d14695f 100644 --- a/Project.toml +++ b/Project.toml @@ -34,6 +34,7 @@ Setfield = "efcf1570-3423-57d1-acb7-fd33fddbac46" SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b" StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" +SymbolicIndexingInterface = "2efcf032-c050-4f8e-a9bb-153293bab1f5" SymbolicUtils = "d1185830-fcd6-423d-90d6-eec64667417b" TreeViews = "a2a6695c-b41b-5b7d-aed9-dbfdeacea5d7" @@ -71,6 +72,7 @@ SciMLBase = "1.8, 2" Setfield = "0.7, 0.8, 1" SpecialFunctions = "0.7, 0.8, 0.9, 0.10, 1.0, 2" StaticArrays = "1.1" +SymbolicIndexingInterface = "0.3" SymbolicUtils = "1.4" TreeViews = "0.3" julia = "1.6" diff --git a/src/Symbolics.jl b/src/Symbolics.jl index f9565f4dc..e0afeaa75 100644 --- a/src/Symbolics.jl +++ b/src/Symbolics.jl @@ -35,6 +35,8 @@ using PrecompileTools using RuntimeGeneratedFunctions using SciMLBase, IfElse using MacroTools + + using SymbolicIndexingInterface end @reexport using SymbolicUtils RuntimeGeneratedFunctions.init(@__MODULE__) diff --git a/src/variable.jl b/src/variable.jl index 19f101503..4d09212ab 100644 --- a/src/variable.jl +++ b/src/variable.jl @@ -406,7 +406,10 @@ end getsource(x, val=_fail) = getmetadata(unwrap(x), VariableSource, val) -getname(x, val=_fail) = _getname(unwrap(x), val) +SymbolicIndexingInterface.symbolic_type(::Type{<:Symbolics.Num}) = ScalarSymbolic() +SymbolicIndexingInterface.symbolic_type(::Type{<:Symbolics.Arr}) = ArraySymbolic() + +SymbolicIndexingInterface.getname(x, val=_fail) = _getname(unwrap(x), val) function getparent(x, val=_fail) maybe_parent = getmetadata(x, Symbolics.GetindexParent, nothing) diff --git a/src/wrapper-types.jl b/src/wrapper-types.jl index c36fc9eca..f1007dbd3 100644 --- a/src/wrapper-types.jl +++ b/src/wrapper-types.jl @@ -17,9 +17,9 @@ function set_where(subt, supert) Expr(:where, supert, Ts...) end -getname(x::Symbol) = x +SymbolicIndexingInterface.getname(x::Symbol) = x -function getname(x::Expr) +function SymbolicIndexingInterface.getname(x::Expr) @assert x.head == :curly return x.args[1] end From aa50d3206d169ffa9f673ce7b8096d2ac2c77cda Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Tue, 21 Nov 2023 11:09:39 +0530 Subject: [PATCH 3/4] refactor: implement new `hasname` function from SII --- src/variable.jl | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/src/variable.jl b/src/variable.jl index 4d09212ab..b2b3d3880 100644 --- a/src/variable.jl +++ b/src/variable.jl @@ -409,6 +409,12 @@ getsource(x, val=_fail) = getmetadata(unwrap(x), VariableSource, val) SymbolicIndexingInterface.symbolic_type(::Type{<:Symbolics.Num}) = ScalarSymbolic() SymbolicIndexingInterface.symbolic_type(::Type{<:Symbolics.Arr}) = ArraySymbolic() +SymbolicIndexingInterface.hasname(x::Union{Num,Arr}) = hasname(unwrap(x)) + +function SymbolicIndexingInterface.hasname(x::Symbolic) + issym(x) || !istree(x) || istree(x) && (issym(operation(x)) || operation(x) == getindex) +end + SymbolicIndexingInterface.getname(x, val=_fail) = _getname(unwrap(x), val) function getparent(x, val=_fail) From 2ded1ac70ad52a932156372909b5e24ef75bf930 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Mon, 11 Dec 2023 17:06:44 +0530 Subject: [PATCH 4/4] test: add downstream tests for SymbolicIndexingInterface --- test/runtests.jl | 9 ++++++++ ...c_indexing_interface_parameter_indexing.jl | 23 +++++++++++++++++++ test/symbolic_indexing_interface_trait.jl | 12 ++++++++++ 3 files changed, 44 insertions(+) create mode 100644 test/symbolic_indexing_interface_parameter_indexing.jl create mode 100644 test/symbolic_indexing_interface_trait.jl diff --git a/test/runtests.jl b/test/runtests.jl index e73d5ffca..27d013a61 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -47,6 +47,15 @@ if GROUP == "All" || GROUP == "Core" @safetestset "LogExpFunctions Test" begin include("logexpfunctions.jl") end end +if GROUP == "All" || GROUP == "Core" || GROUP == "SymbolicIndexingInterface" + @safetestset "SymbolicIndexingInterface Trait Test" begin + include("symbolic_indexing_interface_trait.jl") + end + @safetestset "SymbolicIndexingInterface Parameter Indexing Test" begin + include("symbolic_indexing_interface_parameter_indexing.jl") + end +end + if GROUP == "Downstream" activate_downstream_env() #@time @safetestset "ParameterizedFunctions MATLABDiffEq Regression Test" begin include("downstream/ParameterizedFunctions_MATLAB.jl") end diff --git a/test/symbolic_indexing_interface_parameter_indexing.jl b/test/symbolic_indexing_interface_parameter_indexing.jl new file mode 100644 index 000000000..a05831484 --- /dev/null +++ b/test/symbolic_indexing_interface_parameter_indexing.jl @@ -0,0 +1,23 @@ +using SymbolicIndexingInterface +using Symbolics + +struct FakeIntegrator{P} + p::P +end + +SymbolicIndexingInterface.symbolic_container(fp::FakeIntegrator) = fp.sys +SymbolicIndexingInterface.parameter_values(fp::FakeIntegrator) = fp.p + +@variables a[1:2] b +sys = SymbolCache([:x, :y, :z], [a[1], a[2], b], [:t]) +p = [1.0, 2.0, 3.0] +fi = FakeIntegrator(copy(p)) +for (i, sym) in [(1, a[1]), (2, a[2]), (3, b), ([1,2], a), ([1, 3], [a[1], b]), ((2, 3), (a[2], b))] + get = getp(sys, sym) + set! = setp(sys, sym) + true_value = i isa Tuple ? getindex.((p,), i) : p[i] + @test get(fi) == true_value + set!(fi, 0.5 .* i) + @test get(fi) == 0.5 .* i + set!(fi, true_value) +end diff --git a/test/symbolic_indexing_interface_trait.jl b/test/symbolic_indexing_interface_trait.jl new file mode 100644 index 000000000..52d1579ae --- /dev/null +++ b/test/symbolic_indexing_interface_trait.jl @@ -0,0 +1,12 @@ +using Symbolics +using SymbolicUtils +using SymbolicIndexingInterface + +@test all(symbolic_type.([SymbolicUtils.BasicSymbolic, Symbolics.Num]) .== + (ScalarSymbolic(),)) +@test symbolic_type(Symbolics.Arr) == ArraySymbolic() +@variables x +@test symbolic_type(x) == ScalarSymbolic() +@variables y[1:3] +@test symbolic_type(y) == ArraySymbolic() +@test all(symbolic_type.(collect(y)) .== (ScalarSymbolic(),))