From 7ca6c736f226e8463f409beeb18300f24e31a292 Mon Sep 17 00:00:00 2001 From: Orjan Ameye Date: Wed, 20 Mar 2024 10:37:15 +0100 Subject: [PATCH] Fix `StochSystem(::CoupledODEs)` method (#20) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Fix `StochSystem(::CoupledODEs)` method * add iip kwarg --------- Co-authored-by: Reyk Börner --- src/StochSystem.jl | 4 ++-- src/trajectories/simulation.jl | 12 +++++++----- 2 files changed, 9 insertions(+), 7 deletions(-) diff --git a/src/StochSystem.jl b/src/StochSystem.jl index b9489a36..c0381662 100644 --- a/src/StochSystem.jl +++ b/src/StochSystem.jl @@ -68,5 +68,5 @@ to_cds(sys::StochSystem) = CoupledODEs(sys) Converts a [`CoupledODEs`](https://juliadynamics.github.io/DynamicalSystems.jl/stable/tutorial/#DynamicalSystemsBase.CoupledODEs) system into a [`StochSystem`](@ref). """ -StochSystem(ds::DynamicalSystemsBase.CoupledODEs; σ=0.0, g=idfunc, pg=nothing, Σ=I(length(get_state(ds))), process="WhiteGauss") = -StochSystem(dynamic_rule(ds), [ds.p0], get_state(ds), σ, g, pg, Σ, process) +StochSystem(ds::DynamicalSystemsBase.CoupledODEs, σ=0.0, g=idfunc, pg=nothing, Σ=I(length(get_state(ds))), process="WhiteGauss") = +StochSystem(dynamic_rule(ds), ds.p0, get_state(ds), σ, g, pg, Σ, process) \ No newline at end of file diff --git a/src/trajectories/simulation.jl b/src/trajectories/simulation.jl index 2cc9081d..f24a12af 100644 --- a/src/trajectories/simulation.jl +++ b/src/trajectories/simulation.jl @@ -24,9 +24,10 @@ function simulate(sys::StochSystem, init::State; solver=EM(), callback=nothing, progress=true, + iip=is_iip(sys.f), kwargs...) - prob = SDEProblem(sys.f, σg(sys), init, (0, tmax), p(sys), noise=stochprocess(sys)) + prob = SDEProblem{iip}(sys.f, σg(sys), init, (0, tmax), p(sys), noise=stochprocess(sys)) solve(prob, solver; dt=dt, callback=callback, progress=progress, kwargs...) end; @@ -43,7 +44,7 @@ This function integrates `sys.f` forward in time, using the [`ODEProblem`](https * `callback=nothing`: callback condition * `kwargs...`: keyword arguments for [`solve(ODEProblem)`](https://docs.sciml.ai/DiffEqDocs/stable/basics/common_solver_opts/#solver_options) -For more info, see [`ODEProblem`](https://diffeq.sciml.ai/stable/types/ode_types/#SciMLBase.ODEProblem). +For more info, see [`ODEProblem`](https://diffeq.sciml.ai/stable/types/ode_types/#SciMLBase.ODEProblem). For stochastic integration, see [`simulate`](@ref). > Warning: This function has only been tested for the `Euler()` solver. @@ -53,8 +54,9 @@ function relax(sys::StochSystem, init::State; tmax=1e3, solver=Euler(), callback=nothing, + iip=is_iip(sys.f), kwargs...) - - prob = ODEProblem(sys.f, init, (0, tmax), p(sys)) + + prob = ODEProblem{iip}(sys.f, init, (0, tmax), p(sys)) solve(prob, solver; dt=dt, callback=callback, kwargs...) -end; \ No newline at end of file +end;