diff --git a/src/graph_solve.jl b/src/graph_solve.jl index 232729a..d0cb344 100644 --- a/src/graph_solve.jl +++ b/src/graph_solve.jl @@ -212,14 +212,15 @@ end #---------------------------------------------------------- function continuous_condition(out, u, t, integrator) - (;params_partitioned, state_types_val) = integrator.p + (;params_partitioned, state_types_val, connection_matrices) = integrator.p states_partitioned = to_vec_o_states(u.x, state_types_val) - _continuous_condition!(out, states_partitioned, params_partitioned, t) + _continuous_condition!(out, states_partitioned, params_partitioned, connection_matrices, t) end function _continuous_condition!(out, states_partitioned ::NTuple{Len, Any}, params_partitioned ::NTuple{Len, Any}, + connection_matrices, t) where {Len} idx = 0 @@ -227,7 +228,9 @@ function _continuous_condition!(out, if has_continuous_events(eltype(states_partitioned[i])) for j ∈ eachindex(states_partitioned[i]) idx += 1 - out[idx] = continuous_event_condition(Subsystem(states_partitioned[i][j], params_partitioned[i][j]), t) + F = ForeachConnectedSubsystem{i}(j, states_partitioned, params_partitioned, connection_matrices) + sys = Subsystem(states_partitioned[i][j], params_partitioned[i][j]) + out[idx] = continuous_event_condition(sys, t, F) end end end @@ -291,7 +294,8 @@ tany(f, coll; kwargs...) = tmapreduce(f, |, coll; kwargs...) @nexprs $Len i -> begin if has_discrete_events(eltype(states_partitioned[i])) for j ∈ eachindex(states_partitioned[i]) - discrete_event_condition(Subsystem(states_partitioned[i][j], params_partitioned[i][j]), t) && return true + F = ForeachConnectedSubsystem{i}(j, states_partitioned, params_partitioned, connection_matrices) + discrete_event_condition(Subsystem(states_partitioned[i][j], params_partitioned[i][j]), t, F) && return true end end end @@ -333,8 +337,8 @@ end sys = Subsystem(states_partitioned[i][j], params_partitioned[i][j]) sview = @view states_partitioned[i][j] pview = @view params_partitioned[i][j] - if discrete_event_condition(sys, t) - F = ForeachConnectedSubsystem{i}(j, states_partitioned, params_partitioned, connection_matrices) + F = ForeachConnectedSubsystem{i}(j, states_partitioned, params_partitioned, connection_matrices) + if discrete_event_condition(sys, t, F) if discrete_events_require_inputs(sys) input = calculate_inputs(Val(i), j, states_partitioned, params_partitioned, connection_matrices) apply_discrete_event!(integrator, sview, pview, sys, F, input) @@ -453,11 +457,36 @@ struct ForeachConnectedSubsystem{k, Len, NConn, S, P, CMs} end end -@generated function ((;l, - states_partitioned, - params_partitioned, - connection_matrices)::ForeachConnectedSubsystem{k, Len, NConn})(f::F) where {k, Len, NConn, F} +# @generated function ((;l, +# states_partitioned, +# params_partitioned, +# connection_matrices)::ForeachConnectedSubsystem{k, Len, NConn})(f::F) where {k, Len, NConn, F} +# quote +# @nexprs $Len i -> begin +# @nexprs $NConn nc -> begin +# M = connection_matrices[nc][k, i] +# if M isa NotConnected +# nothing +# else +# for j ∈ eachindex(states_partitioned[i]) +# @inbounds conn = M[l, j] +# if !iszero(conn) +# @inbounds states_view_dst = @view states_partitioned[i][j] +# @inbounds params_view_dst = @view params_partitioned[i][j] +# sys_dst = Subsystem(states_view_dst[], params_view_dst[]) +# f(conn, sys_dst, states_view_dst, params_view_dst) +# end +# end +# end +# end +# end +# end +# end + +@generated function Base.mapreduce(f::F, op::Op, FCS::ForeachConnectedSubsystem{k, Len, NConn}; init) where {k, Len, NConn, F, Op} quote + (;l, states_partitioned, params_partitioned, connection_matrices) = FCS + state = init @nexprs $Len i -> begin @nexprs $NConn nc -> begin M = connection_matrices[nc][k, i] @@ -470,11 +499,14 @@ end @inbounds states_view_dst = @view states_partitioned[i][j] @inbounds params_view_dst = @view params_partitioned[i][j] sys_dst = Subsystem(states_view_dst[], params_view_dst[]) - f(conn, sys_dst, states_view_dst, params_view_dst) + res = f(conn, sys_dst, states_view_dst, params_view_dst) + state = op(state, res) end end end end end - end + state + end end +(FCS::ForeachConnectedSubsystem)(f::F) where {F} = mapreduce(f, (_, _) -> nothing, FCS; init=nothing)