Skip to content

Commit

Permalink
fixup! test: test implementation of SII parameter timeseries interface
Browse files Browse the repository at this point in the history
  • Loading branch information
AayushSabharwal committed May 21, 2024
1 parent f296f97 commit b509d04
Showing 1 changed file with 15 additions and 10 deletions.
25 changes: 15 additions & 10 deletions test/downstream/symbol_indexing.jl
Original file line number Diff line number Diff line change
Expand Up @@ -219,7 +219,7 @@ sts = @variables x(t)[1:3]=[1, 2, 3.0] y(t)=1.0
ps = @parameters p[1:3] = [1, 2, 3]
eqs = [collect(D.(x) .~ x)
D(y) ~ norm(x) * y - x[1]]
@named sys = ODESystem(eqs, t, [sts...;], [ps...;])
@named sys = ODESystem(eqs, t, [sts...;], ps)
sys = complete(sys)
prob = ODEProblem(sys, [], (0, 1.0))
sol = solve(prob, Tsit5())
Expand Down Expand Up @@ -414,9 +414,9 @@ sol = solve(remake(prob), Tsit5())

kpidx = parameter_index(cl, kp)
kpval = 1.0
ud1idx = parameter_index(cl, Hold(ud1))
ud1idx = timeseries_parameter_index(cl, Hold(ud1))
ud1val = [val[ud1idx.parameter_idx] for val in sol.discretes[ud1idx.timeseries_idx].u]
ud2idx = parameter_index(cl, Hold(ud2))
ud2idx = timeseries_parameter_index(cl, Hold(ud2))
ud2val = [val[ud2idx.parameter_idx] for val in sol.discretes[ud2idx.timeseries_idx].u]
ridx = parameter_index(cl, r)
rval = 2.0
Expand All @@ -430,16 +430,21 @@ for (sym, val, buffer, check_inference) in [
((kpidx, ridx), (kpval, rval), zeros(2), true),
([kp, ridx], [kpval, rval], zeros(2), true),
((kp, ridx), (kpval, rval), zeros(2), true),
([kp, Hold(ud1)], [kpval, ud1val[end]], zeros(2), true),
# indexes are of different types, so not inferred
([kp, Hold(ud1)], [kpval, ud1val[end]], zeros(2), false),
((kp, Hold(ud1)), (kpval, ud1val[end]), zeros(2), true),
([kpidx, Hold(ud1)], [kpval, ud1val[end]], zeros(2), true),
([kpidx, Hold(ud1)], [kpval, ud1val[end]], zeros(2), false),
((kpidx, Hold(ud1)), (kpval, ud1val[end]), zeros(2), true),
# not technically valid, but need to test getp behavior
([Hold(ud1) + Hold(ud2), kp], [ud1val[end] + ud2val[end], kpval], zeros(2), true),
((Hold(ud1) + Hold(ud2), kp), (ud1val[end] + ud2val[end], kpval), zeros(2), true),
# inference broken because splatting MTKParameters is not type stable
([Hold(ud1) + Hold(ud2), kp], [ud1val[end] + ud2val[end], kpval], zeros(2), missing),
((Hold(ud1) + Hold(ud2), kp), (ud1val[end] + ud2val[end], kpval), zeros(2), missing),
]
getter = getp(sys, sym)
if check_inference
getter = getp(cl, sym)
if check_inference === missing
@test_broken @inferred getter(sol)
@test_broken @inferred getter(prob)
elseif check_inference
@inferred getter(sol)
@inferred getter(prob)
end
Expand All @@ -451,7 +456,7 @@ for (sym, val, buffer, check_inference) in [
buffer .= 0
getter(buffer, prob)
@test buffer == collect(val)
if check_inference
if check_inference === missing || check_inference
@inferred getter(buffer, sol)
@inferred getter(buffer, prob)
end
Expand Down

0 comments on commit b509d04

Please sign in to comment.