Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add AbstractProblemType #548

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,7 @@ Preferences = "1.3"
Printf = "1.9"
PyCall = "1.96"
PythonCall = "0.9"
QuasiMonteCarlo = "0.3"
RCall = "0.13.18"
RecipesBase = "0.7.0, 0.8, 1.0"
RecursiveArrayTools = "2.33"
Expand All @@ -83,7 +84,6 @@ Statistics = "1"
SymbolicIndexingInterface = "0.2"
Tables = "1"
TruncatedStacktraces = "1"
QuasiMonteCarlo = "0.3"
Zygote = "0.6"
julia = "1.9"

Expand Down
32 changes: 16 additions & 16 deletions ext/SciMLBaseChainRulesCoreExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,12 @@ import ChainRulesCore
import ChainRulesCore: NoTangent, @non_differentiable

function ChainRulesCore.rrule(config::ChainRulesCore.RuleConfig{
>:ChainRulesCore.HasReverseMode,
},
::typeof(getindex),
VA::ODESolution,
sym,
j::Integer)
>:ChainRulesCore.HasReverseMode,
},
::typeof(getindex),
VA::ODESolution,
sym,
j::Integer)
function ODESolution_getindex_pullback(Δ)
i = issymbollike(sym) ? sym_to_index(sym, VA) : sym
if i === nothing
Expand Down Expand Up @@ -94,11 +94,11 @@ function ChainRulesCore.rrule(::Type{SDEProblem}, args...; kwargs...)
end

function ChainRulesCore.rrule(::Type{
<:ODESolution{T1, T2, T3, T4, T5, T6, T7, T8, T9, T10,
T11, T12,
}}, u,
args...) where {T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11,
T12}
<:ODESolution{T1, T2, T3, T4, T5, T6, T7, T8, T9, T10,
T11, T12,
}}, u,
args...) where {T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11,
T12}
function ODESolutionAdjoint(ȳ)
(NoTangent(), ȳ, ntuple(_ -> NoTangent(), length(args))...)
end
Expand All @@ -108,10 +108,10 @@ function ChainRulesCore.rrule(::Type{
end

function ChainRulesCore.rrule(::Type{
<:ODESolution{uType, tType, isinplace, P, NP, F, G, K,
ND,
}}, u,
args...) where {uType, tType, isinplace, P, NP, F, G, K, ND}
<:ODESolution{uType, tType, isinplace, P, NP, F, G, K,
ND,
}}, u,
args...) where {uType, tType, isinplace, P, NP, F, G, K, ND}
function SDESolutionAdjoint(ȳ)
(NoTangent(), ȳ, ntuple(_ -> NoTangent(), length(args))...)
end
Expand All @@ -132,4 +132,4 @@ function ChainRulesCore.rrule(::SciMLBase.EnsembleSolution, sim, time, converged
out, EnsembleSolution_adjoint
end

end
end
5 changes: 4 additions & 1 deletion ext/SciMLBasePythonCallExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,10 @@ function SciMLBase.numargs(f::Py)
pyconvert(Int, length(first(inspect.getfullargspec(f2))) - inspect.ismethod(f2))
end

_pyconvert(x::Py) = pyisinstance(x, pybuiltins.list) ? _promoting_collect(_pyconvert(x) for x in x) : pyconvert(Any, x)
function _pyconvert(x::Py)
pyisinstance(x, pybuiltins.list) ? _promoting_collect(_pyconvert(x) for x in x) :
pyconvert(Any, x)
end
_pyconvert(x::PyList) = _promoting_collect(_pyconvert(x) for x in x)
_pyconvert(x) = x

Expand Down
2 changes: 1 addition & 1 deletion ext/SciMLBaseRCallExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,4 +8,4 @@ function SciMLBase.isinplace(f::RFunction, args...; kwargs...)
false
end

end
end
54 changes: 27 additions & 27 deletions ext/SciMLBaseZygoteExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,9 @@
using Zygote: @adjoint, pullback
import Zygote: literal_getproperty
using SciMLBase
using SciMLBase: ODESolution, issymbollike, sym_to_index, remake,
getobserved, build_solution, EnsembleSolution,
NonlinearSolution, AbstractTimeseriesSolution
using SciMLBase: ODESolution, issymbollike, sym_to_index, remake,
getobserved, build_solution, EnsembleSolution,
NonlinearSolution, AbstractTimeseriesSolution

# This method resolves the ambiguity with the pullback defined in
# RecursiveArrayToolsZygoteExt
Expand Down Expand Up @@ -85,7 +85,7 @@
end

@adjoint function Zygote.literal_getproperty(sim::EnsembleSolution,
::Val{:u})
::Val{:u})
sim.u, p̄ -> (EnsembleSolution(p̄, 0.0, true, sim.stats),)
end

Expand All @@ -107,17 +107,17 @@
}(u,
args...) where {T1, T2, T3, T4, T5, T6, T7, T8,
T9, T10, T11, T12}
function ODESolutionAdjoint(ȳ)
(ȳ, ntuple(_ -> nothing, length(args))...)
end
ODESolution{T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12}(u, args...),
ODESolutionAdjoint
function ODESolutionAdjoint(ȳ)
(ȳ, ntuple(_ -> nothing, length(args))...)

Check warning on line 111 in ext/SciMLBaseZygoteExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/SciMLBaseZygoteExt.jl#L110-L111

Added lines #L110 - L111 were not covered by tests
end

ODESolution{T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12}(u, args...),

Check warning on line 114 in ext/SciMLBaseZygoteExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/SciMLBaseZygoteExt.jl#L114

Added line #L114 was not covered by tests
ODESolutionAdjoint
end

@adjoint function SDEProblem{uType, tType, isinplace, P, NP, F, G, K, ND}(u,
args...) where
{uType, tType, isinplace, P, NP, F, G, K, ND}
args...) where
{uType, tType, isinplace, P, NP, F, G, K, ND}
function SDESolutionAdjoint(ȳ)
(ȳ, ntuple(_ -> nothing, length(args))...)
end
Expand All @@ -126,24 +126,24 @@
end

@adjoint function NonlinearSolution{T, N, uType, R, P, A, O, uType2}(u,
args...) where {
T,
N,
uType,
R,
P,
A,
O,
uType2,
}
args...) where {
T,
N,
uType,
R,
P,
A,
O,
uType2,
}
function NonlinearSolutionAdjoint(ȳ)
(ȳ, ntuple(_ -> nothing, length(args))...)
end
NonlinearSolution{T, N, uType, R, P, A, O, uType2}(u, args...), NonlinearSolutionAdjoint
end

@adjoint function literal_getproperty(sol::AbstractTimeseriesSolution,
::Val{:u})
::Val{:u})
function solu_adjoint(Δ)
zerou = zero(sol.prob.u0)
_Δ = @. ifelse(Δ === nothing, (zerou,), Δ)
Expand All @@ -153,7 +153,7 @@
end

@adjoint function literal_getproperty(sol::SciMLBase.AbstractNoTimeSolution,
::Val{:u})
::Val{:u})
function solu_adjoint(Δ)
zerou = zero(sol.prob.u0)
_Δ = @. ifelse(Δ === nothing, zerou, Δ)
Expand All @@ -163,7 +163,7 @@
end

@adjoint function literal_getproperty(sol::SciMLBase.OptimizationSolution,
::Val{:u})
::Val{:u})
function solu_adjoint(Δ)
zerou = zero(sol.u)
_Δ = @. ifelse(Δ === nothing, zerou, Δ)
Expand Down Expand Up @@ -214,8 +214,8 @@
end

@adjoint function SciMLBase.responsible_map(f,
args::Union{AbstractArray, Tuple
}...)
args::Union{AbstractArray, Tuple
}...)
∇responsible_map(__context__, f, args...)
end

Expand Down
36 changes: 19 additions & 17 deletions src/SciMLBase.jl
Original file line number Diff line number Diff line change
Expand Up @@ -653,23 +653,23 @@ function unwrapped_f(f::FunctionWrappersWrappers.FunctionWrappersWrapper)
end

function specialization(::Union{ODEFunction{iip, specialize},
SDEFunction{iip, specialize}, DDEFunction{iip, specialize},
SDDEFunction{iip, specialize},
DAEFunction{iip, specialize},
DynamicalODEFunction{iip, specialize},
SplitFunction{iip, specialize},
DynamicalSDEFunction{iip, specialize},
SplitSDEFunction{iip, specialize},
DynamicalDDEFunction{iip, specialize},
DiscreteFunction{iip, specialize},
ImplicitDiscreteFunction{iip, specialize},
RODEFunction{iip, specialize},
NonlinearFunction{iip, specialize},
OptimizationFunction{iip, specialize},
BVPFunction{iip, specialize},
IntegralFunction{iip, specialize},
BatchIntegralFunction{iip, specialize}}) where {iip,
specialize}
SDEFunction{iip, specialize}, DDEFunction{iip, specialize},
SDDEFunction{iip, specialize},
DAEFunction{iip, specialize},
DynamicalODEFunction{iip, specialize},
SplitFunction{iip, specialize},
DynamicalSDEFunction{iip, specialize},
SplitSDEFunction{iip, specialize},
DynamicalDDEFunction{iip, specialize},
DiscreteFunction{iip, specialize},
ImplicitDiscreteFunction{iip, specialize},
RODEFunction{iip, specialize},
NonlinearFunction{iip, specialize},
OptimizationFunction{iip, specialize},
BVPFunction{iip, specialize},
IntegralFunction{iip, specialize},
BatchIntegralFunction{iip, specialize}}) where {iip,
specialize}
specialize
end

Expand All @@ -688,6 +688,8 @@ include("operators/common_defaults.jl")
include("symbolic_utils.jl")
include("performance_warnings.jl")

abstract type AbstractProblemType end

include("problems/discrete_problems.jl")
include("problems/implicit_discrete_problems.jl")
include("problems/steady_state_problems.jl")
Expand Down
Loading
Loading