From b509d04716824b65b8956a03788347d61f5bdebe Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Tue, 21 May 2024 20:30:03 +0530 Subject: [PATCH] fixup! test: test implementation of SII parameter timeseries interface --- test/downstream/symbol_indexing.jl | 25 +++++++++++++++---------- 1 file changed, 15 insertions(+), 10 deletions(-) diff --git a/test/downstream/symbol_indexing.jl b/test/downstream/symbol_indexing.jl index 5313151dc..f93aafe02 100644 --- a/test/downstream/symbol_indexing.jl +++ b/test/downstream/symbol_indexing.jl @@ -219,7 +219,7 @@ sts = @variables x(t)[1:3]=[1, 2, 3.0] y(t)=1.0 ps = @parameters p[1:3] = [1, 2, 3] eqs = [collect(D.(x) .~ x) D(y) ~ norm(x) * y - x[1]] -@named sys = ODESystem(eqs, t, [sts...;], [ps...;]) +@named sys = ODESystem(eqs, t, [sts...;], ps) sys = complete(sys) prob = ODEProblem(sys, [], (0, 1.0)) sol = solve(prob, Tsit5()) @@ -414,9 +414,9 @@ sol = solve(remake(prob), Tsit5()) kpidx = parameter_index(cl, kp) kpval = 1.0 -ud1idx = parameter_index(cl, Hold(ud1)) +ud1idx = timeseries_parameter_index(cl, Hold(ud1)) ud1val = [val[ud1idx.parameter_idx] for val in sol.discretes[ud1idx.timeseries_idx].u] -ud2idx = parameter_index(cl, Hold(ud2)) +ud2idx = timeseries_parameter_index(cl, Hold(ud2)) ud2val = [val[ud2idx.parameter_idx] for val in sol.discretes[ud2idx.timeseries_idx].u] ridx = parameter_index(cl, r) rval = 2.0 @@ -430,16 +430,21 @@ for (sym, val, buffer, check_inference) in [ ((kpidx, ridx), (kpval, rval), zeros(2), true), ([kp, ridx], [kpval, rval], zeros(2), true), ((kp, ridx), (kpval, rval), zeros(2), true), - ([kp, Hold(ud1)], [kpval, ud1val[end]], zeros(2), true), + # indexes are of different types, so not inferred + ([kp, Hold(ud1)], [kpval, ud1val[end]], zeros(2), false), ((kp, Hold(ud1)), (kpval, ud1val[end]), zeros(2), true), - ([kpidx, Hold(ud1)], [kpval, ud1val[end]], zeros(2), true), + ([kpidx, Hold(ud1)], [kpval, ud1val[end]], zeros(2), false), ((kpidx, Hold(ud1)), (kpval, ud1val[end]), zeros(2), true), # not technically valid, but need to test getp behavior - ([Hold(ud1) + Hold(ud2), kp], [ud1val[end] + ud2val[end], kpval], zeros(2), true), - ((Hold(ud1) + Hold(ud2), kp), (ud1val[end] + ud2val[end], kpval), zeros(2), true), + # inference broken because splatting MTKParameters is not type stable + ([Hold(ud1) + Hold(ud2), kp], [ud1val[end] + ud2val[end], kpval], zeros(2), missing), + ((Hold(ud1) + Hold(ud2), kp), (ud1val[end] + ud2val[end], kpval), zeros(2), missing), ] - getter = getp(sys, sym) - if check_inference + getter = getp(cl, sym) + if check_inference === missing + @test_broken @inferred getter(sol) + @test_broken @inferred getter(prob) + elseif check_inference @inferred getter(sol) @inferred getter(prob) end @@ -451,7 +456,7 @@ for (sym, val, buffer, check_inference) in [ buffer .= 0 getter(buffer, prob) @test buffer == collect(val) - if check_inference + if check_inference === missing || check_inference @inferred getter(buffer, sol) @inferred getter(buffer, prob) end