Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor iterator. #10

Merged
merged 2 commits into from
Dec 16, 2023
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 17 additions & 30 deletions src/interface.jl
Original file line number Diff line number Diff line change
@@ -1,49 +1,36 @@
using DormandPrince.DP5Core: DP5Solver, dopri5

mutable struct DP5Iterator{T <: Real}
struct DP5Iterator{T <: Real}
solver::DP5Solver
times::AbstractVector{T}
index::Int

function DP5Iterator(solver::DP5Solver, times::AbstractVector{T}) where {T <: Real}
new{T}(solver, times, 1)
end
end

# gets the first (t,y) so we DO NOT increment the index
# gets the first (t,y), return index which is the state
# here we choose 2 because 1 is the initial state which
# is has been returned by the iterator
function Base.iterate(dp5_iterator::DP5Iterator)
# integrate to next time
if dp5_iterator.index <= length(dp5_iterator.times)
# integrate to next time
dopri5(dp5_iterator.solver, dp5_iterator.times[dp5_iterator.index])
# return time and state
return (dp5_iterator.times[dp5_iterator.index], dp5_iterator.solver.y), nothing
else
return nothing
end

length(dp5_iterator.times) == 0 && return nothing # empty iterator
# integrate to first time
integrate(dp5_iterator.solver, first(dp5_iterator.times))
# return value and index which is the state
return (dp5_iterator.times[1], dp5_iterator.solver.y), 2
end

# don't really need the state here because we can just acccess it from the solver
# gets subsequent (t, y) so we SHOULD increment the index
function Base.iterate(dp5_iterator::DP5Iterator, state)
if dp5_iterator.index < length(dp5_iterator.times)
dp5_iterator.index += 1
# integrate to next time
dopri5(dp5_iterator.solver, dp5_iterator.times[dp5_iterator.index])
# return time and state
return (dp5_iterator.times[dp5_iterator.index], dp5_iterator.solver.y), nothing
else
return nothing
end
# gets the next (t,y), return index+! which is the updated state
function Base.iterate(dp5_iterator::DP5Iterator, index::Int)
index > length(dp5_iterator.times) && return nothing # end of iterator
# integrate to next time
integrate(dp5_iterator.solver, dp5_iterator.times[index])
# return time and state
return (dp5_iterator.times[index], dp5_iterator.solver.y), index+1
end

# 3 modes of operation for integrate
# 1. integrate(solver, time) -> state (modify solver object in place)
# 2. integrate(solver, times) -> iterator
# 3. integrate(callback, solver, times) -> vector of states with callback applied

integrate(solver::DP5Solver, time::Real) = dopri5(solver, time)
integrate(solver::DP5Solver, time::Real) = dopri5(solver, time)
integrate(solver::DP5Solver, times::AbstractVector{T}) where {T <: Real} = DP5Iterator(solver, times)


Expand Down
Loading