diff --git a/Project.toml b/Project.toml index 4cf597727..fe512f136 100644 --- a/Project.toml +++ b/Project.toml @@ -28,8 +28,15 @@ SymbolicIndexingInterface = "2efcf032-c050-4f8e-a9bb-153293bab1f5" Tables = "bd369af6-aec1-5ad0-b16a-f7cc5008161c" TruncatedStacktraces = "781d530d-4396-4725-bb49-402e4bee1e77" +[weakdeps] +ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" + +[extensions] +SciMLBaseChainRulesCoreExt = "ChainRulesCore" + [compat] ArrayInterface = "6, 7" +ChainRulesCore = "1.15" CommonSolve = "0.2" ConstructionBase = "1" DocStringExtensions = "0.8, 0.9" @@ -51,6 +58,7 @@ TruncatedStacktraces = "1" julia = "1.6" [extras] +ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" DelayDiffEq = "bcd4f6db-9728-5f36-b5f7-82caef46ccdb" OrdinaryDiffEq = "1dea7af3-3e70-54e6-95c3-0bf5283fa5ed" Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f" @@ -60,4 +68,4 @@ StochasticDiffEq = "789caeaf-c7a9-5a7d-9973-96adeb23e2a0" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" [targets] -test = ["Pkg", "SafeTestsets", "Test", "StaticArrays"] +test = ["Pkg", "SafeTestsets", "Test", "StaticArrays", "ChainRulesCore"] diff --git a/ext/SciMLBaseChainRulesCoreExt.jl b/ext/SciMLBaseChainRulesCoreExt.jl new file mode 100644 index 000000000..e71390ef3 --- /dev/null +++ b/ext/SciMLBaseChainRulesCoreExt.jl @@ -0,0 +1,35 @@ +module SciMLBaseChainRulesCoreExt + +using SciMLBase +isdefined(Base, :get_extension) ? (import ChainRulesCore) : (import ..ChainRulesCore) + +function ChainRulesCore.rrule(::Type{ + <:SciMLBase.PDETimeSeriesSolution{T, N, uType, Disc, Sol, DType, tType, domType, ivType, dvType, + P, A, + IType}}, u, + args...) where {T, N, uType, Disc, Sol, DType, tType, domType, ivType, dvType, + P, A, + IType} + function PDETimeSeriesSolutionAdjoint(ȳ) + (NoTangent(), ȳ, ntuple(_ -> NoTangent(), length(args))...) + end + + SciMLBase.PDETimeSeriesSolution{T, N, uType, Disc, Sol, DType, tType, domType, ivType, dvType, + P, A, + IType}(u, args...), PDETimeSeriesSolutionAdjoint +end + +function ChainRulesCore.rrule(::Type{ + <:SciMLBase.PDENoTimeSolution{T, N, uType, Disc, Sol, domType, ivType, dvType, P, A, + IType}}, u, + args...) where {T, N, uType, Disc, Sol, domType, ivType, dvType, P, A, + IType} + function PDENoTimeSolutionAdjoint(ȳ) + (NoTangent(), ȳ, ntuple(_ -> NoTangent(), length(args))...) + end + + SciMLBase.PDENoTimeSolution{T, N, uType, Disc, Sol, domType, ivType, dvType, P, A, + IType}(u, args...), PDENoTimeSolutionAdjoint +end + +end diff --git a/src/SciMLBase.jl b/src/SciMLBase.jl index 69e4e0699..e743fcd64 100644 --- a/src/SciMLBase.jl +++ b/src/SciMLBase.jl @@ -748,6 +748,12 @@ function wrapfun_oop end function wrapfun_iip end function unwrap_fw end +@static if !isdefined(Base, :get_extension) + function __init__() + @require ChainRulesCore="d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" begin include("../ext/SciMLBaseChainRulesCore.jl") end + end +end + export ReturnCode export DEAlgorithm, SciMLAlgorithm, DEProblem, DEAlgorithm, DESolution, SciMLSolution