Skip to content

Commit

Permalink
refactor iterator.
Browse files Browse the repository at this point in the history
  • Loading branch information
weinbe58 committed Dec 15, 2023
1 parent c62bfe7 commit 6e02473
Showing 1 changed file with 17 additions and 30 deletions.
47 changes: 17 additions & 30 deletions src/interface.jl
Original file line number Diff line number Diff line change
@@ -1,49 +1,36 @@
using DormandPrince.DP5Core: DP5Solver, dopri5

mutable struct DP5Iterator{T <: Real}
struct DP5Iterator{T <: Real}
solver::DP5Solver
times::AbstractVector{T}
index::Int

function DP5Iterator(solver::DP5Solver, times::AbstractVector{T}) where {T <: Real}
new{T}(solver, times, 1)
end
end

# gets the first (t,y) so we DO NOT increment the index
# gets the first (t,y), return index which is the state
# here we choose 2 because 1 is the initial state which
# is has been returned by the iterator
function Base.iterate(dp5_iterator::DP5Iterator)
# integrate to next time
if dp5_iterator.index <= length(dp5_iterator.times)
# integrate to next time
dopri5(dp5_iterator.solver, dp5_iterator.times[dp5_iterator.index])
# return time and state
return (dp5_iterator.times[dp5_iterator.index], dp5_iterator.solver.y), nothing
else
return nothing
end

length(dp5_iterator.times) == 0 && return nothing # empty iterator
# integrate to first time
integrate(dp5_iterator.solver, first(dp5_iterator.times))
# return value and index which is the state
return (dp5_iterator.times[dp5_iterator.index], dp5_iterator.solver.y), 2
end

# don't really need the state here because we can just acccess it from the solver
# gets subsequent (t, y) so we SHOULD increment the index
function Base.iterate(dp5_iterator::DP5Iterator, state)
if dp5_iterator.index < length(dp5_iterator.times)
dp5_iterator.index += 1
# integrate to next time
dopri5(dp5_iterator.solver, dp5_iterator.times[dp5_iterator.index])
# return time and state
return (dp5_iterator.times[dp5_iterator.index], dp5_iterator.solver.y), nothing
else
return nothing
end
# gets the next (t,y), return index+! which is the updated state
function Base.iterate(dp5_iterator::DP5Iterator, index::Int)
index > length(dp5_iterator.times) && return nothing # end of iterator
# integrate to next time
integrate(dp5_iterator.solver, dp5_iterator.times[index])
# return time and state
return (dp5_iterator.times[index], dp5_iterator.solver.y), index+1
end

# 3 modes of operation for integrate
# 1. integrate(solver, time) -> state (modify solver object in place)
# 2. integrate(solver, times) -> iterator
# 3. integrate(callback, solver, times) -> vector of states with callback applied

integrate(solver::DP5Solver, time::Real) = dopri5(solver, time)
integrate(solver::DP5Solver, time::Real) = dopri5(solver, time)
integrate(solver::DP5Solver, times::AbstractVector{T}) where {T <: Real} = DP5Iterator(solver, times)


Expand Down

0 comments on commit 6e02473

Please sign in to comment.