Skip to content

Commit 463aef9

Browse files
authored
Merge pull request #640 from JuliaSymbolics/s/revert-revert-cc-deprecation
Revert revert dependent array convention
2 parents 539c0d2 + f251f5e commit 463aef9

9 files changed

+119
-58
lines changed

Project.toml

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "Symbolics"
22
uuid = "0c5d862f-8b57-4792-8d23-62f2024744c7"
33
authors = ["Shashi Gowda <[email protected]>"]
4-
version = "4.8.3"
4+
version = "4.9.0"
55

66
[deps]
77
ArrayInterfaceCore = "30b0a656-2188-435a-8636-2ec0e6a096e2"

src/arrays.jl

+15-4
Original file line numberDiff line numberDiff line change
@@ -624,11 +624,22 @@ struct ScalarizeCache end
624624

625625
function scalarize_op(f, arr, idx)
626626
if hasmetadata(arr, ScalarizeCache) && getmetadata(arr, ScalarizeCache)[] !== nothing
627-
wrap(getmetadata(arr, ScalarizeCache)[][idx...])
627+
getmetadata(arr, ScalarizeCache)[][idx...]
628628
else
629-
thing = f(scalarize.(map(wrap, arguments(arr)))...)
629+
thing = f(scalarize.(arguments(arr))...)
630+
if metadata(arr) != nothing
631+
# forward any metadata
632+
try
633+
thing = metadata(thing, metadata(arr))
634+
catch err
635+
@warn "could not attach metadata of subexpression $arr to the scalarized form at idx"
636+
end
637+
end
638+
if !hasmetadata(arr, ScalarizeCache)
639+
arr = setmetadata(arr, ScalarizeCache, Ref{Any}(nothing))
640+
end
630641
getmetadata(arr, ScalarizeCache)[] = thing
631-
wrap(thing[idx...])
642+
thing[idx...]
632643
end
633644
end
634645

@@ -645,7 +656,7 @@ end
645656
_det(x, lp) = det(x, laplace=lp)
646657

647658
function scalarize_op(f::typeof(_det), arr)
648-
det(map(wrap, collect(arguments(arr)[1])), laplace=arguments(arr)[2])
659+
unwrap(det(map(wrap, collect(arguments(arr)[1])), laplace=arguments(arr)[2]))
649660
end
650661

651662
@wrapped function LinearAlgebra.det(x::AbstractMatrix; laplace=true)

src/diff.jl

+12-6
Original file line numberDiff line numberDiff line change
@@ -53,9 +53,13 @@ Base.hash(D::Differential, u::UInt) = hash(D.x, xor(u, 0xdddddddddddddddd))
5353
_isfalse(occ::Bool) = occ === false
5454
_isfalse(occ::Term) = _isfalse(operation(occ))
5555

56-
function occursin_info(x, expr)
56+
function occursin_info(x, expr, fail = true)
5757
if symtype(expr) <: AbstractArray
58-
error("Differentiation of expressions involving arrays and array variables is not yet supported.")
58+
if fail
59+
error("Differentiation with array expressions is not yet supported")
60+
else
61+
return occursin(x, expr)
62+
end
5963
end
6064

6165
# Allow scalarized expressions
@@ -65,11 +69,13 @@ function occursin_info(x, expr)
6569
is_scalar_indexed(operation(ex)))
6670
end
6771

72+
# x[1] == x[1] but not x[2]
6873
if is_scalar_indexed(x) && is_scalar_indexed(expr) &&
6974
isequal(first(arguments(x)), first(arguments(expr)))
7075
return isequal(operation(x), operation(expr)) &&
7176
isequal(arguments(x), arguments(expr))
7277
end
78+
7379
if is_scalar_indexed(x) && is_scalar_indexed(expr) &&
7480
!occursin(first(arguments(x)), first(arguments(expr)))
7581
return false
@@ -83,17 +89,17 @@ function occursin_info(x, expr)
8389
if isequal(x, expr)
8490
true
8591
else
86-
args = map(a->occursin_info(x, a), arguments(expr))
92+
args = map(a->occursin_info(x, a, operation(expr) !== getindex), arguments(expr))
8793
if all(_isfalse, args)
8894
return false
8995
end
9096
Term{Real}(true, args)
9197
end
9298
end
9399

94-
function occursin_info(x, expr::Sym)
95-
if symtype(expr) <: AbstractArray
96-
error("Differentiation of expressions involving arrays and array variables is not yet supported.")
100+
function occursin_info(x, expr::Sym, fail)
101+
if symtype(expr) <: AbstractArray && fail
102+
error("Differentiation of expressions involving arrays and array variables is not yet supported.")
97103
end
98104
isequal(x, expr)
99105
end

src/utils.jl

+21-7
Original file line numberDiff line numberDiff line change
@@ -111,19 +111,23 @@ function diff2term(O)
111111
else
112112
ds = nothing
113113
end
114+
d_separator = 'ˍ'
114115

115116
if ds === nothing
116117
return similarterm(O, operation(O), map(diff2term, arguments(O)), metadata=metadata(O))
117118
else
118119
oldop = operation(O)
119120
if oldop isa Sym
120121
opname = string(nameof(oldop))
122+
args = arguments(O)
121123
elseif oldop isa Term && operation(oldop) === getindex
122124
opname = string(nameof(arguments(oldop)[1]))
123-
else
124-
throw(ArgumentError("A differentiated state's operation must be a `Sym`, so states like `D(u + u)` are disallowed. Got `$oldop`."))
125+
args = arguments(O)
126+
elseif oldop == getindex
127+
args = arguments(O)
128+
opname = string(tosymbol(args[1]), "[", map(tosymbol, args[2:end])..., "]")
129+
return Sym{symtype(O)}(Symbol(opname, d_separator, ds))
125130
end
126-
d_separator = 'ˍ'
127131
newname = occursin(d_separator, opname) ? Symbol(opname, ds) : Symbol(opname, d_separator, ds)
128132
return setname(similarterm(O, rename(oldop, newname), arguments(O), metadata=metadata(O)), newname)
129133
end
@@ -162,6 +166,9 @@ function tosymbol(t::Term; states=nothing, escape=true)
162166
args = arguments(t)
163167
elseif operation(t) isa Differential
164168
term = diff2term(t)
169+
if issym(term)
170+
return nameof(term)
171+
end
165172
op = Symbol(operation(term))
166173
args = arguments(term)
167174
else
@@ -181,8 +188,6 @@ function lower_varname(var::Symbolic, idv, order)
181188
return diff2term(var)
182189
end
183190

184-
var_from_nested_derivative(x, i=0) = (missing, missing)
185-
186191
### OOPS
187192

188193
struct Unknown end
@@ -208,8 +213,17 @@ function makesubscripts(n)
208213
end
209214
end
210215

211-
var_from_nested_derivative(x::Term,i=0) = operation(x) isa Differential ? var_from_nested_derivative(arguments(x)[1], i + 1) : (x, i)
212-
var_from_nested_derivative(x::Sym,i=0) = (x, i)
216+
function var_from_nested_derivative(x,i=0)
217+
x = unwrap(x)
218+
if issym(x)
219+
(x, i)
220+
elseif istree(x)
221+
operation(x) isa Differential ?
222+
var_from_nested_derivative(first(arguments(x)), i + 1) : (x, i)
223+
else
224+
error("Not a well formed derivative expression $x")
225+
end
226+
end
213227

214228
function degree(p::Sym, sym=nothing)
215229
if sym === nothing

src/variable.jl

+48-3
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,7 @@ function scalarize_getindex(x, parent=Ref{Any}(x))
5353
scalarize_getindex(r, parent)
5454
end
5555
else
56-
xx = scalarize(x)
56+
xx = unwrap(scalarize(x))
5757
xx = metadata(xx, metadata(x))
5858
if symtype(xx) <: FnType
5959
setmetadata(CallWithMetadata(xx, metadata(xx)), GetindexParent, parent[])
@@ -125,10 +125,22 @@ function _parse_vars(macroname, type, x, transform=identity)
125125
isruntime, v = unwrap_runtime_var(v)
126126
iscall = Meta.isexpr(v, :call)
127127
isarray = Meta.isexpr(v, :ref)
128+
if iscall && Meta.isexpr(v.args[1], :ref)
129+
@warn("The variable syntax $v is deprecated. Use $(Expr(:ref, Expr(:call, v.args[1].args[1], v.args[2]), v.args[1].args[2:end]...)) instead.
130+
The former creates an array of functions, while the latter creates an array valued function.
131+
The deprecated syntax will cause an error in the next major release of Symbolics.
132+
This change will facilitate better implementation of various features of Symbolics.")
133+
end
128134
issym = v isa Symbol
129135
@assert iscall || isarray || issym "@$macroname expects a tuple of expressions or an expression of a tuple (`@$macroname x y z(t) v[1:3] w[1:2,1:4]` or `@$macroname x y z(t) v[1:3] w[1:2,1:4] k=1.0`)"
130136

131-
if iscall
137+
if isarray && Meta.isexpr(v.args[1], :call)
138+
# This is the new syntax
139+
isruntime, fname = unwrap_runtime_var(v.args[1].args[1])
140+
call_args = map(lastunwrap_runtime_var, @view v.args[1].args[2:end])
141+
size = v.args[2:end]
142+
var_name, expr = construct_dep_array_vars(macroname, fname, type′, call_args, size, val, options, transform, isruntime)
143+
elseif iscall
132144
isruntime, fname = unwrap_runtime_var(v.args[1])
133145
call_args = map(lastunwrap_runtime_var, @view v.args[2:end])
134146
var_name, expr = construct_vars(macroname, fname, type′, call_args, val, options, transform, isruntime)
@@ -144,6 +156,33 @@ function _parse_vars(macroname, type, x, transform=identity)
144156
return ex
145157
end
146158

159+
function construct_dep_array_vars(macroname, lhs, type, call_args, indices, val, prop, transform, isruntime)
160+
ndim = :($length(($(indices...),)))
161+
vname = !isruntime ? Meta.quot(lhs) : lhs
162+
if call_args[1] == :..
163+
ex = :($CallWithMetadata($Sym{$FnType{Tuple, Array{$type, $ndim}}}($vname)))
164+
else
165+
ex = :($Sym{$FnType{Tuple, Array{$type, $ndim}}}($vname)(map($unwrap, ($(call_args...),))...))
166+
end
167+
ex = :($setmetadata($ex, $ArrayShapeCtx, ($(indices...),)))
168+
169+
if val !== nothing
170+
ex = :($setdefaultval($ex, $val))
171+
end
172+
ex = setprops_expr(ex, prop, macroname, Meta.quot(lhs))
173+
#ex = :($scalarize_getindex($ex))
174+
175+
ex = :($wrap($ex))
176+
177+
if call_args[1] == :..
178+
ex = :($transform($ex))
179+
end
180+
if isruntime
181+
lhs = gensym(lhs)
182+
end
183+
lhs, :($lhs = $ex)
184+
end
185+
147186
function construct_vars(macroname, v, type, call_args, val, prop, transform, isruntime)
148187
issym = v isa Symbol
149188
isarray = isa(v, Expr) && v.head == :ref
@@ -212,7 +251,7 @@ function Base.show(io::IO, c::CallWithMetadata)
212251
end
213252

214253
function (f::CallWithMetadata)(args...)
215-
wrap(metadata(f.f(args...), metadata(f)))
254+
metadata(unwrap(f.f(map(unwrap, args)...)), metadata(f))
216255
end
217256

218257
function construct_var(macroname, var_name, type, call_args, val, prop)
@@ -351,6 +390,12 @@ const _fail = Dict()
351390
_getname(x, _) = nameof(x)
352391
_getname(x::Symbol, _) = x
353392
function _getname(x::Symbolic, val)
393+
if istree(x) && issym(operation(x))
394+
return nameof(operation(x))
395+
end
396+
if !hasmetadata(x, Symbolics.GetindexParent) && istree(x) && operation(x) == getindex
397+
return _getname(arguments(x)[1], val)
398+
end
354399
ss = getsource(x, nothing)
355400
if ss === nothing
356401
ss = getsource(getparent(x), val)

test/arrays.jl

+10-22
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ using SymbolicUtils: Sym, term, operation
2323
@test symtype(X[i, j]) == Real
2424
@test symtype(X[1, j]) == Real
2525

26-
@variables t x[1:2](t)
26+
@variables t x(t)[1:2]
2727
@test isequal(get_variables(0 ~ x[1]), [x[1]])
2828
@test Set(get_variables(2x)) == Set(collect(x)) # both array elements are present
2929
@test isequal(get_variables(2x[1]), [x[1]])
@@ -39,14 +39,14 @@ end
3939
@test isequal(unwrap(X[:, 2]), Symbolics.@arrayop((i,), XX[i, 2], term=XX[:, 2]))
4040
@test isequal(unwrap(X[:, 2:3]), Symbolics.@arrayop((i, j), XX[i, j], (j in 2:3), term=XX[:, 2:3]))
4141

42-
@variables t x[1:4](t)
42+
@variables t x(t)[1:4]
4343
@syms i::Int
44-
@test isequal(x[i], operation(unwrap(x[i]))(t))
44+
@test isequal(x[i], operation(unwrap(x))(t)[i])
4545
end
4646

4747
getdef(v) = getmetadata(v, Symbolics.VariableDefaultValue)
4848
@testset "broadcast & scalarize" begin
49-
@variables A[1:5,1:3]=42 b[1:3]=[2, 3, 5] t x[1:4](t) u[1:1]
49+
@variables A[1:5,1:3]=42 b[1:3]=[2, 3, 5] t x(t)[1:4] u[1:1]
5050
AA = Symbolics.scalarize(A)
5151
bb = Symbolics.scalarize(b)
5252
@test all(isequal(42), getdef.(AA))
@@ -112,18 +112,6 @@ getdef(v) = getmetadata(v, Symbolics.VariableDefaultValue)
112112
@test isequal(Symbolics.scalarize(x', (1, 1)), x[1])
113113
end
114114

115-
@testset "Parent" begin
116-
@variables t x[1:4](t)
117-
x = unwrap(x)
118-
@test Symbolics.getparent(collect(x)[1]).metadata === x.metadata
119-
end
120-
121-
@testset "Parent" begin
122-
@variables t x[1:4](t)
123-
x = unwrap(x)
124-
@test Symbolics.getparent(collect(x)[1]).metadata === x.metadata
125-
end
126-
127115
n = 2
128116
A = randn(n, n)
129117
foo(x) = A * x # a function to represent symbolically, note, if this function is defined inside the testset, it's not found by the function fun_eval = eval(fun_ex)
@@ -238,7 +226,7 @@ end
238226
n = rand(8:32)
239227
N = 2
240228

241-
@variables t u[fill(1:n, N)...](t)
229+
@variables t u(t)[fill(1:n, N)...]
242230

243231
Igrid = CartesianIndices((fill(1:n, N)...,))
244232
Iinterior = CartesianIndices((fill(2:n-1, N)...,))
@@ -287,7 +275,7 @@ end
287275

288276
@testset "Brusselator stencil" begin
289277
n = 8
290-
@variables t u[1:n, 1:n](t) v[1:n, 1:n](t)
278+
@variables t u(t)[1:n, 1:n] v(t)[1:n, 1:n]
291279

292280
brusselator_f(x, y, t) = (((x - 0.3)^2 + (y - 0.6)^2) <= 0.1^2) * (t >= 1.1) * 5.0
293281

@@ -330,11 +318,11 @@ end
330318

331319
f, g = build_function(dtu, u, v, t, expression=Val{false})
332320
du = zeros(Num, 8, 8)
333-
#f(du, u,v,t)
334-
#@test isequal(collect(du), collect(dtu))
321+
f(du, u,v,t)
322+
@test isequal(collect(du), collect(dtu))
335323

336-
#@test isequal(collect(dtu), collect(1 .+ v .* u.^2 .- (A + 1) .* u .+ alpha .* lapu .+ s))
337-
#@test isequal(collect(dtv), collect(A .* u .- u.^2 .* v .+ alpha .* lapv))
324+
@test isequal(collect(dtu), collect(1 .+ v .* u.^2 .- (A + 1) .* u .+ alpha .* lapu .+ s))
325+
@test isequal(collect(dtv), collect(A .* u .- u.^2 .* v .+ alpha .* lapv))
338326
end
339327

340328
@testset "Partial array substitution" begin

test/diff.jl

+5-5
Original file line numberDiff line numberDiff line change
@@ -6,14 +6,14 @@ using Symbolics: value
66
# Derivatives
77
@variables t σ ρ β
88
@variables x y z
9-
@variables uu(t) uuˍt(t) v[1:3](t)
9+
@variables uu(t) uuˍt(t) v(t)[1:3]
1010
D = Differential(t)
1111
D2 = Differential(t)^2
1212
Dx = Differential(x)
1313

1414
@test Symbol(D(D(uu))) === Symbol("uuˍtt(t)")
1515
@test Symbol(D(uuˍt)) === Symbol(D(D(uu)))
16-
@test Symbol(D(v[2])) === Symbol("getindex(vˍt, 2)(t)")
16+
@test Symbol(D(v[2])) === Symbol("v(t)[2]ˍt")
1717

1818
test_equal(a, b) = @test isequal(simplify(a), simplify(b))
1919

@@ -245,7 +245,7 @@ sp_hess = Symbolics.sparsehessian(rr, X)
245245
@test isequal(map(spoly, findnz(sparse(reference_hes))[3]), map(spoly, findnz(sp_hess)[3]))
246246

247247
#96
248-
@variables t x[1:4](t)[1:4](t)
248+
@variables t x(t)[1:4] (t)[1:4]
249249
expression = sin(x[1] + x[2] + x[3] + x[4]) |> Differential(t) |> expand_derivatives
250250
expression2 = substitute(expression, Dict(collect(Differential(t).(x) .=> ẋ)))
251251
@test isequal(expression2, (ẋ[1] + ẋ[2] + ẋ[3] + ẋ[4])*cos(x[1] + x[2] + x[3] + x[4]))
@@ -300,7 +300,7 @@ end
300300
# make sure derivative(x[1](t), y) does not fail
301301
let
302302
@variables t a(t)
303-
vars = collect(@variables(x[1:1](t))[1])
303+
vars = collect(@variables(x(t)[1:1])[1])
304304
ps = collect(@variables(ps[1:1])[1])
305305
@test Symbolics.derivative(ps[1], vars[1]) == 0
306306
@test Symbolics.derivative(ps[1], a) == 0
@@ -332,7 +332,7 @@ xt2 = substitute(x, [t => t2])
332332
# 581
333333
#
334334
let
335-
@variables x[1:3](t)
335+
@variables x(t)[1:3]
336336
@test iszero(Symbolics.derivative(x[1], x[2]))
337337
end
338338

0 commit comments

Comments
 (0)