Skip to content

Commit

Permalink
fix: various bug and test fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
AayushSabharwal committed May 31, 2024
1 parent 152e9c2 commit af37ae8
Show file tree
Hide file tree
Showing 4 changed files with 15 additions and 9 deletions.
12 changes: 8 additions & 4 deletions src/systems/abstractsystem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -427,7 +427,8 @@ function SymbolicIndexingInterface.is_parameter(sys::AbstractSystem, sym)
sym = unwrap(sym)
if has_index_cache(sys) && (ic = get_index_cache(sys)) !== nothing
return sym isa ParameterIndex || is_parameter(ic, sym) ||

Check warning on line 429 in src/systems/abstractsystem.jl

View check run for this annotation

Codecov / codecov/patch

src/systems/abstractsystem.jl#L429

Added line #L429 was not covered by tests
istree(sym) && operation(sym) === getindex &&
istree(sym) &&
operation(sym) === getindex &&
is_parameter(ic, first(arguments(sym)))
end
if unwrap(sym) isa Int
Expand Down Expand Up @@ -462,10 +463,12 @@ function SymbolicIndexingInterface.parameter_index(sys::AbstractSystem, sym)
end
elseif istree(sym) && operation(sym) === getindex &&
(idx = parameter_index(ic, first(arguments(sym)))) !== nothing
if idx.portion isa SciMLStructures.Discrete && idx.idx[2] == idx.idx[3] == nothing
if idx.portion isa SciMLStructures.Discrete &&

Check warning on line 466 in src/systems/abstractsystem.jl

View check run for this annotation

Codecov / codecov/patch

src/systems/abstractsystem.jl#L466

Added line #L466 was not covered by tests
idx.idx[2] == idx.idx[3] == nothing
return nothing

Check warning on line 468 in src/systems/abstractsystem.jl

View check run for this annotation

Codecov / codecov/patch

src/systems/abstractsystem.jl#L468

Added line #L468 was not covered by tests
else
ParameterIndex(idx.portion, (idx.idx..., arguments(sym)[(begin + 1):end]...))
ParameterIndex(

Check warning on line 470 in src/systems/abstractsystem.jl

View check run for this annotation

Codecov / codecov/patch

src/systems/abstractsystem.jl#L470

Added line #L470 was not covered by tests
idx.portion, (idx.idx..., arguments(sym)[(begin + 1):end]...))
end
else
nothing
Expand All @@ -485,7 +488,8 @@ end
function SymbolicIndexingInterface.parameter_index(sys::AbstractSystem, sym::Symbol)
if has_index_cache(sys) && (ic = get_index_cache(sys)) !== nothing
idx = parameter_index(ic, sym)
if idx === nothing || idx.portion isa SciMLStructures.Discrete && idx.idx[2] == idx.idx[3] == 0
if idx === nothing ||

Check warning on line 491 in src/systems/abstractsystem.jl

View check run for this annotation

Codecov / codecov/patch

src/systems/abstractsystem.jl#L490-L491

Added lines #L490 - L491 were not covered by tests
idx.portion isa SciMLStructures.Discrete && idx.idx[2] == idx.idx[3] == 0
return nothing

Check warning on line 493 in src/systems/abstractsystem.jl

View check run for this annotation

Codecov / codecov/patch

src/systems/abstractsystem.jl#L493

Added line #L493 was not covered by tests
else
return idx

Check warning on line 495 in src/systems/abstractsystem.jl

View check run for this annotation

Codecov / codecov/patch

src/systems/abstractsystem.jl#L495

Added line #L495 was not covered by tests
Expand Down
7 changes: 4 additions & 3 deletions src/systems/parameter_buffer.jl
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,7 @@ function MTKParameters(
end
end
tunable_buffer = narrow_buffer_type.(tunable_buffer)
disc_buffer = narrow_buffer_type.(disc_buffer)
disc_buffer = broadcast.(narrow_buffer_type, disc_buffer)

Check warning on line 135 in src/systems/parameter_buffer.jl

View check run for this annotation

Codecov / codecov/patch

src/systems/parameter_buffer.jl#L135

Added line #L135 was not covered by tests
const_buffer = narrow_buffer_type.(const_buffer)
nonnumeric_buffer = narrow_buffer_type.(nonnumeric_buffer)

Expand All @@ -149,7 +149,8 @@ function MTKParameters(
oop, iip = build_function(dep_exprs, p...)
update_function_iip, update_function_oop = RuntimeGeneratedFunctions.@RuntimeGeneratedFunction(iip),
RuntimeGeneratedFunctions.@RuntimeGeneratedFunction(oop)
update_function_iip(ArrayPartition(dep_buffer), tunable_buffer..., disc_buffer...,
update_function_iip(ArrayPartition(dep_buffer), tunable_buffer...,

Check warning on line 152 in src/systems/parameter_buffer.jl

View check run for this annotation

Codecov / codecov/patch

src/systems/parameter_buffer.jl#L152

Added line #L152 was not covered by tests
Iterators.flatten(disc_buffer)...,
const_buffer..., nonnumeric_buffer..., dep_buffer...)
dep_buffer = narrow_buffer_type.(dep_buffer)
else
Expand Down Expand Up @@ -442,7 +443,7 @@ function SymbolicIndexingInterface.remake_buffer(sys, oldbuf::MTKParameters, val
@set! newbuf.dependent = narrow_buffer_type_and_fallback_undefs.(
oldbuf.dependent,
split_into_buffers(
newbuf.dependent_update_oop(newbuf...), oldbuf.dependent, Val(false)))
newbuf.dependent_update_oop(newbuf...), oldbuf.dependent, Val(0)))
end
return newbuf
end
Expand Down
2 changes: 1 addition & 1 deletion test/mtkparameters.jl
Original file line number Diff line number Diff line change
Expand Up @@ -294,9 +294,9 @@ ps = MTKParameters(sys,
yd2 => 2.0 + Sample(ssc)(x), Sample(t, dt)(x) => x,
Sample(ssc)(x) => x, Hold(yd1) => yd1, Hold(yd2) => yd2],
[x => 3.0])
@test SciMLBase.get_saveable_values(ps, 1).x isa Tuple{Vector{Float64}, Vector{Bool}}
tsidx1 = timeseries_parameter_index(sys, flag).timeseries_idx
tsidx2 = 3 - tsidx1
@test SciMLBase.get_saveable_values(ps, tsidx1).x isa Tuple{Vector{Float64}, BitVector}
@test length(ps.discrete[tsidx1][1]) == 3
@test length(ps.discrete[tsidx1][2]) == 1
@test length(ps.discrete[tsidx2][1]) == 3
Expand Down
3 changes: 2 additions & 1 deletion test/symbolic_indexing_interface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,8 @@ using SciMLStructures: Tunable
@test parameter_index(odesys, a) isa ParameterIndex{Tunable, Tuple{Int, Int}}
@test parameter_index(odesys, b) == parameter_index(odesys, :b)
@test parameter_index(odesys, b) isa ParameterIndex{Tunable, Tuple{Int, Int}}
@test parameter_index.((odesys,), [x, y, t, ParameterIndex(Tunable(), (1, 1)), :x, :y,]) ==
@test parameter_index.(
(odesys,), [x, y, t, ParameterIndex(Tunable(), (1, 1)), :x, :y]) ==
[nothing, nothing, nothing, ParameterIndex(Tunable(), (1, 1)), nothing, nothing]
@test isequal(parameter_symbols(odesys), [a, b])
@test all(is_independent_variable.((odesys,), [t, :t]))
Expand Down

0 comments on commit af37ae8

Please sign in to comment.