diff --git a/Project.toml b/Project.toml index f70277d8..a06cf5a3 100644 --- a/Project.toml +++ b/Project.toml @@ -6,6 +6,7 @@ version = "0.4.12" [deps] Buckets = "3235f445-51d8-4100-901d-5b23398ac3ab" DataInterpolations = "82cc6244-b520-54b8-b5a6-8a565e85f1d0" +DiffEqBase = "2b5f629d-d688-5b77-993f-72d75c75574e" DiffEqCallbacks = "459566f4-90b8-5000-8ac3-15dfb0a30def" EnumX = "4e289a0a-7415-4d19-859d-a7e5c4648b56" ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" diff --git a/src/Gradus.jl b/src/Gradus.jl index fa1bb711..b155c049 100644 --- a/src/Gradus.jl +++ b/src/Gradus.jl @@ -6,6 +6,8 @@ using LinearAlgebra: ×, ⋅, norm, det, dot, inv using Parameters +import DiffEqBase + using SciMLBase using OrdinaryDiffEq using DiffEqCallbacks @@ -158,7 +160,7 @@ number of geodesics. Also used to dispatch different tracing problems. abstract type AbstractTrace end """ - AbstractIntegrationParameters + AbstractIntegrationParameters{M} Parameters that are made available at each step of the integration, that need not be constant. For example, the turning points or withing-geometry flags. @@ -167,13 +169,28 @@ The integration parameters should track which spacetime `M` they are parameters Integration parameters must implement - [`set_status_code!`](@ref) - [`get_status_code`](@ref) +- [`get_metric`](@ref) For more complex parameters, may also optionally implement - [`update_integration_parameters!`](@ref) See the documentation of each of the above functions for details of their operation. """ -abstract type AbstractIntegrationParameters end +abstract type AbstractIntegrationParameters{M<:AbstractMetric} end + +# type alias, since this is often used +const MutStatusCode = MVector{1,StatusCodes.T} + +# TODO: temporary fix for https://github.com/SciML/DiffEqBase.jl/issues/918 +function DiffEqBase.anyeltypedual( + ::AbstractIntegrationParameters{<:AbstractMetric{T}}, +) where {T} + if T <: ForwardDiff.Dual + T + else + Any + end +end """ update_integration_parameters!(old::AbstractIntegrationParameters, new::AbstractIntegrationParameters) @@ -197,7 +214,7 @@ end Update the status [`StatusCodes`](@ref) in `p` with `status`. """ -set_status_code!(params::AbstractIntegrationParameters, status::StatusCodes.T) = +set_status_code!(params::AbstractIntegrationParameters, ::StatusCodes.T) = error("Not implemented for $(typeof(params))") """ @@ -208,6 +225,14 @@ Return the status [`StatusCodes`](@ref) in `status`. get_status_code(params::AbstractIntegrationParameters) = error("Not implemented for $(typeof(params))") +""" + get_metric(p::AbstractIntegrationParameters{M})::M where {M} + +Return the [`AbstractMetric`](@ref) `m::M` for which the integration parameters +have been specialised. +""" +get_metric(params::AbstractIntegrationParameters) = + error("Not implemented for $(typeof(params))") """ abstract type AbstractGeodesicPoint diff --git a/src/metrics/kerr-newman-ad.jl b/src/metrics/kerr-newman-ad.jl index 25200731..8d78499e 100644 --- a/src/metrics/kerr-newman-ad.jl +++ b/src/metrics/kerr-newman-ad.jl @@ -78,10 +78,11 @@ function geodesic_ode_problem( end function f(u::SVector{8,T}, p, λ) where {T} @inbounds let x = SVector{4,T}(@view(u[1:4])), v = SVector{4,T}(@view(u[5:8])) - dv = SVector{4,T}(geodesic_equation(m, x, v)) + _m = get_metric(p) + dv = SVector{4,T}(geodesic_equation(_m, x, v)) # add maxwell part dvf = if !(trace.q ≈ 0.0) - F = faraday_tensor(m, x) + F = faraday_tensor(_m, x) q_μ * (F * v) else zero(SVector{4,T}) @@ -95,7 +96,7 @@ function geodesic_ode_problem( f, u_init, time_domain, - IntegrationParameters(StatusCodes.NoStatus); + IntegrationParameters(m, StatusCodes.NoStatus); callback = callback, ) end diff --git a/src/precompile.jl b/src/precompile.jl index 33b8945c..aa713aae 100644 --- a/src/precompile.jl +++ b/src/precompile.jl @@ -1,3 +1,11 @@ +Base.precompile( + Tuple{ + typeof(_second_order_ode_f), + SVector{8,Float64}, + IntegrationParameters{KerrMetric{Float64}}, + Float64, + }, +) # time: 5.720331 Base.precompile( Tuple{ typeof(lineprofile), @@ -5,7 +13,7 @@ Base.precompile( SVector{4,Float64}, GeometricThinDisc{Float64}, }, -) # time: 4.8591228 +) # time: 2.5695212 Base.precompile( Tuple{ typeof(tracegeodesics), @@ -14,7 +22,119 @@ Base.precompile( SVector{4,Float64}, Vararg{Any}, }, -) # time: 1.1855857 +) # time: 1.4936496 +let fbody = try + __lookup_kwbody__( + which(tracegeodesics, (KerrMetric{Float64}, SVector{4,Float64}, Vararg{Any})), + ) + catch missing + end + if !ismissing(fbody) + precompile( + fbody, + ( + Float64, + Float64, + TraceGeodesic{Float64}, + Base.Pairs{Symbol,Union{},Tuple{},NamedTuple{(),Tuple{}}}, + typeof(tracegeodesics), + KerrMetric{Float64}, + SVector{4,Float64}, + Vararg{Any}, + ), + ) + end +end # time: 1.0984381 +Base.precompile( + Tuple{ + typeof(Core.kwcall), + NamedTuple{(:n_samples,),Tuple{Int64}}, + typeof(emissivity_profile), + KerrMetric{Float64}, + GeometricThinDisc{Float64}, + LampPostModel{Float64}, + }, +) # time: 0.9623503 +Base.precompile( + Tuple{ + typeof(Core.kwcall), + NamedTuple{(:n_samples,),Tuple{Int64}}, + typeof(tracecorona), + KerrMetric{Float64}, + GeometricThinDisc{Float64}, + LampPostModel{Float64}, + }, +) # time: 0.4676826 +let fbody = try + __lookup_kwbody__( + which(tracegeodesics, (KerrMetric{Float64}, SVector{4,Float64}, Vararg{Any})), + ) + catch missing + end + if !ismissing(fbody) + precompile( + fbody, + ( + Float64, + Float64, + TraceRadiativeTransfer{Float64}, + Base.Pairs{Symbol,Union{},Tuple{},NamedTuple{(),Tuple{}}}, + typeof(tracegeodesics), + KerrMetric{Float64}, + SVector{4,Float64}, + Vararg{Any}, + ), + ) + end +end # time: 0.3491159 +Base.precompile( + Tuple{ + typeof(Core.kwcall), + NamedTuple{(:n_samples,),Tuple{Int64}}, + typeof(tracegeodesics), + KerrMetric{Float64}, + LampPostModel{Float64}, + GeometricThinDisc{Float64}, + Vararg{Any}, + }, +) # time: 0.3443357 +Base.precompile( + Tuple{ + typeof(Core.kwcall), + NamedTuple{(:image_width, :image_height),Tuple{Int64,Int64}}, + typeof(rendergeodesics), + KerrMetric{Float64}, + SVector{4,Float64}, + GeometricThinDisc{Float64}, + Vararg{Any}, + }, +) # time: 0.19957814 +Base.precompile( + Tuple{ + typeof(Core.kwcall), + NamedTuple{(:trace,),Tuple{TraceRadiativeTransfer{Float64}}}, + typeof(tracegeodesics), + KerrMetric{Float64}, + SVector{4,Float64}, + SVector{4,Float64}, + Vararg{Any}, + }, +) # time: 0.13094269 +Base.precompile(Tuple{Type{RadialDiscProfile},Vector{Float64},Vector{Float64},Vector{Any}}) # time: 0.08506817 +Base.precompile( + Tuple{ + typeof(Core.kwcall), + NamedTuple{ + (:trace, :image_width, :image_height), + Tuple{TraceRadiativeTransfer{Float64},Int64,Int64}, + }, + typeof(rendergeodesics), + KerrMetric{Float64}, + SVector{4,Float64}, + GeometricThinDisc{Float64}, + Vararg{Any}, + }, +) # time: 0.0841068 let fbody = try __lookup_kwbody__( which( @@ -47,59 +167,54 @@ let fbody = try ), ) end -end # time: 0.9053782 +end # time: 0.036170956 Base.precompile( Tuple{ - typeof(Core.kwcall), - NamedTuple{ - (:trace, :image_width, :image_height), - Tuple{TraceRadiativeTransfer{Float64},Int64,Int64}, - }, - typeof(rendergeodesics), + typeof(tracegeodesics), KerrMetric{Float64}, - SVector{4,Float64}, - GeometricThinDisc{Float64}, + Vector{SVector{4,Float64}}, + Vector{SVector{4,Float64}}, Vararg{Any}, }, -) # time: 0.11555993 +) # time: 0.033662505 Base.precompile( Tuple{ - typeof(Core.kwcall), - NamedTuple{(:trace,),Tuple{TraceRadiativeTransfer{Float64}}}, - typeof(tracegeodesics), + typeof(tracing_configuration), + TraceGeodesic{Float64}, KerrMetric{Float64}, SVector{4,Float64}, SVector{4,Float64}, - Vararg{Any}, + GeometricThinDisc{Float64}, + Float64, }, -) # time: 0.04722405 +) # time: 0.030030003 Base.precompile( Tuple{ - typeof(tracegeodesics), + typeof(tracing_configuration), + TraceRadiativeTransfer{Float64}, KerrMetric{Float64}, - Vector{SVector{4,Float64}}, - Vector{SVector{4,Float64}}, - Vararg{Any}, + SVector{4,Float64}, + SVector{4,Float64}, + GeometricThinDisc{Float64}, + Float64, }, -) # time: 0.034065828 +) # time: 0.01967225 Base.precompile( Tuple{ typeof(Core.kwcall), - NamedTuple{(:n_samples,),Tuple{Int64}}, - typeof(tracegeodesics), + NamedTuple{ + (:trajectories, :save_on, :ensemble), + Tuple{Int64,Bool,EnsembleEndpointThreads}, + }, + typeof(tracing_configuration), + TraceGeodesic{Float64}, KerrMetric{Float64}, - LampPostModel{Float64}, + SVector{4,Float64}, + Function, GeometricThinDisc{Float64}, - Vararg{Any}, - }, -) # time: 0.025394669 -Base.precompile( - Tuple{ - typeof(update_integration_parameters!), - RadiativeTransferIntegrationParameters{Vector{Bool}}, - RadiativeTransferIntegrationParameters{Vector{Bool}}, + Float64, }, -) # time: 0.007881916 +) # time: 0.010621044 Base.precompile( Tuple{ typeof(Core.kwcall), @@ -115,4 +230,50 @@ Base.precompile( GeometricThinDisc{Float64}, Float64, }, -) # time: 0.007585501 +) # time: 0.008629832 +Base.precompile( + Tuple{ + typeof(tracing_configuration), + TraceGeodesic{Float64}, + KerrMetric{Float64}, + Vector{SVector{4,Float64}}, + Vector{SVector{4,Float64}}, + GeometricThinDisc{Float64}, + Float64, + }, +) # time: 0.007926913 +Base.precompile( + Tuple{ + typeof(update_integration_parameters!), + RadiativeTransferIntegrationParameters{KerrMetric{Float64},Vector{Bool}}, + RadiativeTransferIntegrationParameters{KerrMetric{Float64},Vector{Bool}}, + }, +) # time: 0.007825662 +Base.precompile( + Tuple{typeof(point_source_equitorial_disc_emissivity),Float64,Any,Float64,Float64}, +) # time: 0.006162999 +Base.precompile( + Tuple{ + typeof(Core.kwcall), + NamedTuple{(:gtol,),Tuple{Float64}}, + typeof(distance_to_disc), + GeometricThinDisc{Float64}, + SVector{9,Float64}, + }, +) # time: 0.003949623 +Base.precompile( + Tuple{ + typeof(Core.kwcall), + NamedTuple{(:gtol,),Tuple{Float64}}, + typeof(distance_to_disc), + GeometricThinDisc{Float64}, + SVector{8,Float64}, + }, +) # time: 0.001979128 +Base.precompile( + Tuple{ + typeof(set_status_code!), + IntegrationParameters{KerrMetric{Float64}}, + Gradus.StatusCodes.T, + }, +) # time: 0.001029539 diff --git a/src/solution-processing.jl b/src/solution-processing.jl index 283264a2..2210f986 100644 --- a/src/solution-processing.jl +++ b/src/solution-processing.jl @@ -128,5 +128,6 @@ function unpack_solution_full( end unpack_solution(gp::AbstractGeodesicPoint) = gp -unpack_solution(sol::SciMLBase.AbstractODESolution) = unpack_solution(sol.prob.f.f.m, sol) +unpack_solution(sol::SciMLBase.AbstractODESolution) = + unpack_solution(get_metric(sol.prob.p), sol) unpack_solution(simsol::SciMLBase.AbstractEnsembleSolution) = map(unpack_solution, simsol.u) diff --git a/src/tracing/geodesic-problem.jl b/src/tracing/geodesic-problem.jl index 9dc4a455..51a615a9 100644 --- a/src/tracing/geodesic-problem.jl +++ b/src/tracing/geodesic-problem.jl @@ -1,10 +1,15 @@ -mutable struct IntegrationParameters <: AbstractIntegrationParameters - status::StatusCodes.T + +struct IntegrationParameters{M} <: AbstractIntegrationParameters{M} + metric::M + status::MutStatusCode + IntegrationParameters(metric::M, status) where {M} = + new{M}(metric, MutStatusCode(status)) end set_status_code!(params::IntegrationParameters, status::StatusCodes.T) = - params.status = status -get_status_code(params::IntegrationParameters) = params.status + params.status[1] = status +get_status_code(params::IntegrationParameters) = params.status[1] +get_metric(params::IntegrationParameters) = params.metric """ geodesic_ode_problem( @@ -14,7 +19,7 @@ get_status_code(params::IntegrationParameters) = params.status vel, time_domain::Tuple, callback - ) + Returns an `OrdinaryDiffEq.ODEProblem{false}`, specifying the ODE problem to be solved. The precise problem depends on the [`AbstractTrace`](@ref) and [`AbstractMetric`](@ref) defined. @@ -62,23 +67,22 @@ function geodesic_ode_problem( time_domain, callback, ) - function geodesic_ode_f(u::SVector{8,T}, p, λ) where {T} - @inbounds let x = SVector{4,T}(@view(u[1:4])), v = SVector{4,T}(@view(u[5:8])) - dv = SVector{4,T}(geodesic_equation(m, x, v)) - vcat(v, dv) - end - end - u_init = vcat(pos, vel) ODEProblem{false}( - geodesic_ode_f, + _second_order_ode_f, u_init, time_domain, - IntegrationParameters(StatusCodes.NoStatus); + IntegrationParameters(m, StatusCodes.NoStatus); callback = callback, ) end +function _second_order_ode_f(u::SVector{8,T}, p, λ) where {T} + @inbounds let x = SVector{4,T}(@view(u[1:4])), v = SVector{4,T}(@view(u[5:8])) + dv = SVector{4,T}(geodesic_equation(get_metric(p), x, v)) + vcat(v, dv) + end +end """ assemble_tracing_problem(trace::AbstractTrace, config::TracingConfiguration) diff --git a/src/tracing/method-implementations/first-order.jl b/src/tracing/method-implementations/first-order.jl index 39f1f63d..0fc6f6ce 100644 --- a/src/tracing/method-implementations/first-order.jl +++ b/src/tracing/method-implementations/first-order.jl @@ -84,21 +84,23 @@ of motion in `p`. """ four_velocity(u, m::AbstractFirstOrderMetric, p) = error("Not implmented for $(typeof(m)).") -mutable struct FirstOrderIntegrationParameters{T} <: AbstractIntegrationParameters +mutable struct FirstOrderIntegrationParameters{M,T} <: AbstractIntegrationParameters{M} + metric::M L::T Q::T r::Int θ::Int changes::Vector{T} status::StatusCodes.T -end + FirstOrderIntegrationParameters(m::M, L, Q, sign_θ, ::Type{T}) where {M,T} = + new{M,T}(m, L, Q, -1, sign_θ, [0.0, 0.0], StatusCodes.NoStatus) -make_parameters(L, Q, sign_θ, ::Type{T}) where {T} = - FirstOrderIntegrationParameters{T}(L, Q, -1, sign_θ, [0.0, 0.0], StatusCodes.NoStatus) +end set_status_code!(params::FirstOrderIntegrationParameters, status::StatusCodes.T) = params.status = status get_status_code(params::FirstOrderIntegrationParameters) = params.status +get_metric(params::FirstOrderIntegrationParameters) = params.metric function update_integration_parameters!( p::FirstOrderIntegrationParameters, @@ -112,6 +114,10 @@ function update_integration_parameters!( p end +function _first_order_ode_f(u, p, λ) + SVector(four_velocity(u, get_metric(p), p)...) +end + function geodesic_ode_problem( ::TraceGeodesic, m::AbstractFirstOrderMetric{T}, @@ -122,13 +128,12 @@ function geodesic_ode_problem( ) where {S,T} L, Q = calc_lq(m, pos, vel) ODEProblem{false}( + _first_order_ode_f, pos, time_domain, - make_parameters(L, Q, vel[2], T); + FirstOrderIntegrationParameters(m, L, Q, vel[2], T); callback = callback, - ) do u, p, λ - SVector(four_velocity(u, m, p)...) - end + ) end convert_velocity_type(::StaticVector{S,T}, v) where {S,T} = convert(SVector{S,T}, v) diff --git a/src/tracing/radiative-transfer-problem.jl b/src/tracing/radiative-transfer-problem.jl index abd8d662..76e616df 100644 --- a/src/tracing/radiative-transfer-problem.jl +++ b/src/tracing/radiative-transfer-problem.jl @@ -26,19 +26,30 @@ end absorption_coefficient(m::AbstractMetric, d::AbstractAccretionGeometry, x, ν) = 0.0 emissivity_coefficient(m::AbstractMetric, d::AbstractAccretionGeometry, x, ν) = 0.0 -mutable struct RadiativeTransferIntegrationParameters{V} <: AbstractIntegrationParameters - status::StatusCodes.T +struct RadiativeTransferIntegrationParameters{M,V} <: AbstractIntegrationParameters{M} + metric::M + status::MutStatusCode within_geometry::V + RadiativeTransferIntegrationParameters( + metric::M, + status, + within_geometry::V, + ) where {M,V} = new{M,V}(metric, MutStatusCode(status), within_geometry) end -function _radiative_transfer_integration_parameters(status::StatusCodes.T, geometry) +function _radiative_transfer_integration_parameters( + metric::AbstractMetric, + status::StatusCodes.T, + geometry, +) within_geometry = map(!is_finite_disc, geometry) - RadiativeTransferIntegrationParameters(status, within_geometry) + RadiativeTransferIntegrationParameters(metric, status, within_geometry) end set_status_code!(params::RadiativeTransferIntegrationParameters, status::StatusCodes.T) = - params.status = status -get_status_code(params::RadiativeTransferIntegrationParameters) = params.status + params.status[1] = status +get_status_code(params::RadiativeTransferIntegrationParameters) = params.status[1] +get_metric(params::RadiativeTransferIntegrationParameters) = params.metric function update_integration_parameters!( @@ -127,10 +138,10 @@ function radiative_transfer_ode_problem( function f(u::SVector{9,T}, p, λ) where {T} @inbounds let x = SVector{4,T}(u[1:4]), k = SVector{4,T}(u[5:8]), I = u[9] - - dk = SVector{4,T}(geodesic_equation(m, x, k)) + _m = get_metric(p) + dk = SVector{4,T}(geodesic_equation(_m, x, k)) dI = _intensity_delta( - m, + _m, x, k, geometry, @@ -151,7 +162,7 @@ function radiative_transfer_ode_problem( f, u_init, time_domain, - _radiative_transfer_integration_parameters(StatusCodes.NoStatus, geometry); + _radiative_transfer_integration_parameters(m, StatusCodes.NoStatus, geometry); callback = callback, ) end