From 91204fbb992a080d690773d73427ea11e35e66fc Mon Sep 17 00:00:00 2001 From: fjebaker Date: Sun, 30 Jul 2023 11:53:44 +0100 Subject: [PATCH 1/2] feat: spacetime concious integration parameters The integration parameters now hold the metric for the purpose of both dispatching and unpacking geodesic problems and solutions. This has the benefit that the integration function can now be pure, without any closure captures, which hopefully should aid the GPU implementation. Also conveys a slight performance improvement of around 10%. --- src/Gradus.jl | 31 ++- src/metrics/kerr-newman-ad.jl | 7 +- src/precompile.jl | 231 +++++++++++++++--- src/solution-processing.jl | 3 +- src/tracing/geodesic-problem.jl | 32 +-- .../method-implementations/first-order.jl | 21 +- src/tracing/radiative-transfer-problem.jl | 31 ++- 7 files changed, 282 insertions(+), 74 deletions(-) diff --git a/src/Gradus.jl b/src/Gradus.jl index fa1bb711..b155c049 100644 --- a/src/Gradus.jl +++ b/src/Gradus.jl @@ -6,6 +6,8 @@ using LinearAlgebra: ×, ⋅, norm, det, dot, inv using Parameters +import DiffEqBase + using SciMLBase using OrdinaryDiffEq using DiffEqCallbacks @@ -158,7 +160,7 @@ number of geodesics. Also used to dispatch different tracing problems. abstract type AbstractTrace end """ - AbstractIntegrationParameters + AbstractIntegrationParameters{M} Parameters that are made available at each step of the integration, that need not be constant. For example, the turning points or withing-geometry flags. @@ -167,13 +169,28 @@ The integration parameters should track which spacetime `M` they are parameters Integration parameters must implement - [`set_status_code!`](@ref) - [`get_status_code`](@ref) +- [`get_metric`](@ref) For more complex parameters, may also optionally implement - [`update_integration_parameters!`](@ref) See the documentation of each of the above functions for details of their operation. """ -abstract type AbstractIntegrationParameters end +abstract type AbstractIntegrationParameters{M<:AbstractMetric} end + +# type alias, since this is often used +const MutStatusCode = MVector{1,StatusCodes.T} + +# TODO: temporary fix for https://github.com/SciML/DiffEqBase.jl/issues/918 +function DiffEqBase.anyeltypedual( + ::AbstractIntegrationParameters{<:AbstractMetric{T}}, +) where {T} + if T <: ForwardDiff.Dual + T + else + Any + end +end """ update_integration_parameters!(old::AbstractIntegrationParameters, new::AbstractIntegrationParameters) @@ -197,7 +214,7 @@ end Update the status [`StatusCodes`](@ref) in `p` with `status`. """ -set_status_code!(params::AbstractIntegrationParameters, status::StatusCodes.T) = +set_status_code!(params::AbstractIntegrationParameters, ::StatusCodes.T) = error("Not implemented for $(typeof(params))") """ @@ -208,6 +225,14 @@ Return the status [`StatusCodes`](@ref) in `status`. get_status_code(params::AbstractIntegrationParameters) = error("Not implemented for $(typeof(params))") +""" + get_metric(p::AbstractIntegrationParameters{M})::M where {M} + +Return the [`AbstractMetric`](@ref) `m::M` for which the integration parameters +have been specialised. +""" +get_metric(params::AbstractIntegrationParameters) = + error("Not implemented for $(typeof(params))") """ abstract type AbstractGeodesicPoint diff --git a/src/metrics/kerr-newman-ad.jl b/src/metrics/kerr-newman-ad.jl index 25200731..8d78499e 100644 --- a/src/metrics/kerr-newman-ad.jl +++ b/src/metrics/kerr-newman-ad.jl @@ -78,10 +78,11 @@ function geodesic_ode_problem( end function f(u::SVector{8,T}, p, λ) where {T} @inbounds let x = SVector{4,T}(@view(u[1:4])), v = SVector{4,T}(@view(u[5:8])) - dv = SVector{4,T}(geodesic_equation(m, x, v)) + _m = get_metric(p) + dv = SVector{4,T}(geodesic_equation(_m, x, v)) # add maxwell part dvf = if !(trace.q ≈ 0.0) - F = faraday_tensor(m, x) + F = faraday_tensor(_m, x) q_μ * (F * v) else zero(SVector{4,T}) @@ -95,7 +96,7 @@ function geodesic_ode_problem( f, u_init, time_domain, - IntegrationParameters(StatusCodes.NoStatus); + IntegrationParameters(m, StatusCodes.NoStatus); callback = callback, ) end diff --git a/src/precompile.jl b/src/precompile.jl index 33b8945c..aa713aae 100644 --- a/src/precompile.jl +++ b/src/precompile.jl @@ -1,3 +1,11 @@ +Base.precompile( + Tuple{ + typeof(_second_order_ode_f), + SVector{8,Float64}, + IntegrationParameters{KerrMetric{Float64}}, + Float64, + }, +) # time: 5.720331 Base.precompile( Tuple{ typeof(lineprofile), @@ -5,7 +13,7 @@ Base.precompile( SVector{4,Float64}, GeometricThinDisc{Float64}, }, -) # time: 4.8591228 +) # time: 2.5695212 Base.precompile( Tuple{ typeof(tracegeodesics), @@ -14,7 +22,119 @@ Base.precompile( SVector{4,Float64}, Vararg{Any}, }, -) # time: 1.1855857 +) # time: 1.4936496 +let fbody = try + __lookup_kwbody__( + which(tracegeodesics, (KerrMetric{Float64}, SVector{4,Float64}, Vararg{Any})), + ) + catch missing + end + if !ismissing(fbody) + precompile( + fbody, + ( + Float64, + Float64, + TraceGeodesic{Float64}, + Base.Pairs{Symbol,Union{},Tuple{},NamedTuple{(),Tuple{}}}, + typeof(tracegeodesics), + KerrMetric{Float64}, + SVector{4,Float64}, + Vararg{Any}, + ), + ) + end +end # time: 1.0984381 +Base.precompile( + Tuple{ + typeof(Core.kwcall), + NamedTuple{(:n_samples,),Tuple{Int64}}, + typeof(emissivity_profile), + KerrMetric{Float64}, + GeometricThinDisc{Float64}, + LampPostModel{Float64}, + }, +) # time: 0.9623503 +Base.precompile( + Tuple{ + typeof(Core.kwcall), + NamedTuple{(:n_samples,),Tuple{Int64}}, + typeof(tracecorona), + KerrMetric{Float64}, + GeometricThinDisc{Float64}, + LampPostModel{Float64}, + }, +) # time: 0.4676826 +let fbody = try + __lookup_kwbody__( + which(tracegeodesics, (KerrMetric{Float64}, SVector{4,Float64}, Vararg{Any})), + ) + catch missing + end + if !ismissing(fbody) + precompile( + fbody, + ( + Float64, + Float64, + TraceRadiativeTransfer{Float64}, + Base.Pairs{Symbol,Union{},Tuple{},NamedTuple{(),Tuple{}}}, + typeof(tracegeodesics), + KerrMetric{Float64}, + SVector{4,Float64}, + Vararg{Any}, + ), + ) + end +end # time: 0.3491159 +Base.precompile( + Tuple{ + typeof(Core.kwcall), + NamedTuple{(:n_samples,),Tuple{Int64}}, + typeof(tracegeodesics), + KerrMetric{Float64}, + LampPostModel{Float64}, + GeometricThinDisc{Float64}, + Vararg{Any}, + }, +) # time: 0.3443357 +Base.precompile( + Tuple{ + typeof(Core.kwcall), + NamedTuple{(:image_width, :image_height),Tuple{Int64,Int64}}, + typeof(rendergeodesics), + KerrMetric{Float64}, + SVector{4,Float64}, + GeometricThinDisc{Float64}, + Vararg{Any}, + }, +) # time: 0.19957814 +Base.precompile( + Tuple{ + typeof(Core.kwcall), + NamedTuple{(:trace,),Tuple{TraceRadiativeTransfer{Float64}}}, + typeof(tracegeodesics), + KerrMetric{Float64}, + SVector{4,Float64}, + SVector{4,Float64}, + Vararg{Any}, + }, +) # time: 0.13094269 +Base.precompile(Tuple{Type{RadialDiscProfile},Vector{Float64},Vector{Float64},Vector{Any}}) # time: 0.08506817 +Base.precompile( + Tuple{ + typeof(Core.kwcall), + NamedTuple{ + (:trace, :image_width, :image_height), + Tuple{TraceRadiativeTransfer{Float64},Int64,Int64}, + }, + typeof(rendergeodesics), + KerrMetric{Float64}, + SVector{4,Float64}, + GeometricThinDisc{Float64}, + Vararg{Any}, + }, +) # time: 0.0841068 let fbody = try __lookup_kwbody__( which( @@ -47,59 +167,54 @@ let fbody = try ), ) end -end # time: 0.9053782 +end # time: 0.036170956 Base.precompile( Tuple{ - typeof(Core.kwcall), - NamedTuple{ - (:trace, :image_width, :image_height), - Tuple{TraceRadiativeTransfer{Float64},Int64,Int64}, - }, - typeof(rendergeodesics), + typeof(tracegeodesics), KerrMetric{Float64}, - SVector{4,Float64}, - GeometricThinDisc{Float64}, + Vector{SVector{4,Float64}}, + Vector{SVector{4,Float64}}, Vararg{Any}, }, -) # time: 0.11555993 +) # time: 0.033662505 Base.precompile( Tuple{ - typeof(Core.kwcall), - NamedTuple{(:trace,),Tuple{TraceRadiativeTransfer{Float64}}}, - typeof(tracegeodesics), + typeof(tracing_configuration), + TraceGeodesic{Float64}, KerrMetric{Float64}, SVector{4,Float64}, SVector{4,Float64}, - Vararg{Any}, + GeometricThinDisc{Float64}, + Float64, }, -) # time: 0.04722405 +) # time: 0.030030003 Base.precompile( Tuple{ - typeof(tracegeodesics), + typeof(tracing_configuration), + TraceRadiativeTransfer{Float64}, KerrMetric{Float64}, - Vector{SVector{4,Float64}}, - Vector{SVector{4,Float64}}, - Vararg{Any}, + SVector{4,Float64}, + SVector{4,Float64}, + GeometricThinDisc{Float64}, + Float64, }, -) # time: 0.034065828 +) # time: 0.01967225 Base.precompile( Tuple{ typeof(Core.kwcall), - NamedTuple{(:n_samples,),Tuple{Int64}}, - typeof(tracegeodesics), + NamedTuple{ + (:trajectories, :save_on, :ensemble), + Tuple{Int64,Bool,EnsembleEndpointThreads}, + }, + typeof(tracing_configuration), + TraceGeodesic{Float64}, KerrMetric{Float64}, - LampPostModel{Float64}, + SVector{4,Float64}, + Function, GeometricThinDisc{Float64}, - Vararg{Any}, - }, -) # time: 0.025394669 -Base.precompile( - Tuple{ - typeof(update_integration_parameters!), - RadiativeTransferIntegrationParameters{Vector{Bool}}, - RadiativeTransferIntegrationParameters{Vector{Bool}}, + Float64, }, -) # time: 0.007881916 +) # time: 0.010621044 Base.precompile( Tuple{ typeof(Core.kwcall), @@ -115,4 +230,50 @@ Base.precompile( GeometricThinDisc{Float64}, Float64, }, -) # time: 0.007585501 +) # time: 0.008629832 +Base.precompile( + Tuple{ + typeof(tracing_configuration), + TraceGeodesic{Float64}, + KerrMetric{Float64}, + Vector{SVector{4,Float64}}, + Vector{SVector{4,Float64}}, + GeometricThinDisc{Float64}, + Float64, + }, +) # time: 0.007926913 +Base.precompile( + Tuple{ + typeof(update_integration_parameters!), + RadiativeTransferIntegrationParameters{KerrMetric{Float64},Vector{Bool}}, + RadiativeTransferIntegrationParameters{KerrMetric{Float64},Vector{Bool}}, + }, +) # time: 0.007825662 +Base.precompile( + Tuple{typeof(point_source_equitorial_disc_emissivity),Float64,Any,Float64,Float64}, +) # time: 0.006162999 +Base.precompile( + Tuple{ + typeof(Core.kwcall), + NamedTuple{(:gtol,),Tuple{Float64}}, + typeof(distance_to_disc), + GeometricThinDisc{Float64}, + SVector{9,Float64}, + }, +) # time: 0.003949623 +Base.precompile( + Tuple{ + typeof(Core.kwcall), + NamedTuple{(:gtol,),Tuple{Float64}}, + typeof(distance_to_disc), + GeometricThinDisc{Float64}, + SVector{8,Float64}, + }, +) # time: 0.001979128 +Base.precompile( + Tuple{ + typeof(set_status_code!), + IntegrationParameters{KerrMetric{Float64}}, + Gradus.StatusCodes.T, + }, +) # time: 0.001029539 diff --git a/src/solution-processing.jl b/src/solution-processing.jl index 283264a2..2210f986 100644 --- a/src/solution-processing.jl +++ b/src/solution-processing.jl @@ -128,5 +128,6 @@ function unpack_solution_full( end unpack_solution(gp::AbstractGeodesicPoint) = gp -unpack_solution(sol::SciMLBase.AbstractODESolution) = unpack_solution(sol.prob.f.f.m, sol) +unpack_solution(sol::SciMLBase.AbstractODESolution) = + unpack_solution(get_metric(sol.prob.p), sol) unpack_solution(simsol::SciMLBase.AbstractEnsembleSolution) = map(unpack_solution, simsol.u) diff --git a/src/tracing/geodesic-problem.jl b/src/tracing/geodesic-problem.jl index 9dc4a455..51a615a9 100644 --- a/src/tracing/geodesic-problem.jl +++ b/src/tracing/geodesic-problem.jl @@ -1,10 +1,15 @@ -mutable struct IntegrationParameters <: AbstractIntegrationParameters - status::StatusCodes.T + +struct IntegrationParameters{M} <: AbstractIntegrationParameters{M} + metric::M + status::MutStatusCode + IntegrationParameters(metric::M, status) where {M} = + new{M}(metric, MutStatusCode(status)) end set_status_code!(params::IntegrationParameters, status::StatusCodes.T) = - params.status = status -get_status_code(params::IntegrationParameters) = params.status + params.status[1] = status +get_status_code(params::IntegrationParameters) = params.status[1] +get_metric(params::IntegrationParameters) = params.metric """ geodesic_ode_problem( @@ -14,7 +19,7 @@ get_status_code(params::IntegrationParameters) = params.status vel, time_domain::Tuple, callback - ) + Returns an `OrdinaryDiffEq.ODEProblem{false}`, specifying the ODE problem to be solved. The precise problem depends on the [`AbstractTrace`](@ref) and [`AbstractMetric`](@ref) defined. @@ -62,23 +67,22 @@ function geodesic_ode_problem( time_domain, callback, ) - function geodesic_ode_f(u::SVector{8,T}, p, λ) where {T} - @inbounds let x = SVector{4,T}(@view(u[1:4])), v = SVector{4,T}(@view(u[5:8])) - dv = SVector{4,T}(geodesic_equation(m, x, v)) - vcat(v, dv) - end - end - u_init = vcat(pos, vel) ODEProblem{false}( - geodesic_ode_f, + _second_order_ode_f, u_init, time_domain, - IntegrationParameters(StatusCodes.NoStatus); + IntegrationParameters(m, StatusCodes.NoStatus); callback = callback, ) end +function _second_order_ode_f(u::SVector{8,T}, p, λ) where {T} + @inbounds let x = SVector{4,T}(@view(u[1:4])), v = SVector{4,T}(@view(u[5:8])) + dv = SVector{4,T}(geodesic_equation(get_metric(p), x, v)) + vcat(v, dv) + end +end """ assemble_tracing_problem(trace::AbstractTrace, config::TracingConfiguration) diff --git a/src/tracing/method-implementations/first-order.jl b/src/tracing/method-implementations/first-order.jl index 39f1f63d..0fc6f6ce 100644 --- a/src/tracing/method-implementations/first-order.jl +++ b/src/tracing/method-implementations/first-order.jl @@ -84,21 +84,23 @@ of motion in `p`. """ four_velocity(u, m::AbstractFirstOrderMetric, p) = error("Not implmented for $(typeof(m)).") -mutable struct FirstOrderIntegrationParameters{T} <: AbstractIntegrationParameters +mutable struct FirstOrderIntegrationParameters{M,T} <: AbstractIntegrationParameters{M} + metric::M L::T Q::T r::Int θ::Int changes::Vector{T} status::StatusCodes.T -end + FirstOrderIntegrationParameters(m::M, L, Q, sign_θ, ::Type{T}) where {M,T} = + new{M,T}(m, L, Q, -1, sign_θ, [0.0, 0.0], StatusCodes.NoStatus) -make_parameters(L, Q, sign_θ, ::Type{T}) where {T} = - FirstOrderIntegrationParameters{T}(L, Q, -1, sign_θ, [0.0, 0.0], StatusCodes.NoStatus) +end set_status_code!(params::FirstOrderIntegrationParameters, status::StatusCodes.T) = params.status = status get_status_code(params::FirstOrderIntegrationParameters) = params.status +get_metric(params::FirstOrderIntegrationParameters) = params.metric function update_integration_parameters!( p::FirstOrderIntegrationParameters, @@ -112,6 +114,10 @@ function update_integration_parameters!( p end +function _first_order_ode_f(u, p, λ) + SVector(four_velocity(u, get_metric(p), p)...) +end + function geodesic_ode_problem( ::TraceGeodesic, m::AbstractFirstOrderMetric{T}, @@ -122,13 +128,12 @@ function geodesic_ode_problem( ) where {S,T} L, Q = calc_lq(m, pos, vel) ODEProblem{false}( + _first_order_ode_f, pos, time_domain, - make_parameters(L, Q, vel[2], T); + FirstOrderIntegrationParameters(m, L, Q, vel[2], T); callback = callback, - ) do u, p, λ - SVector(four_velocity(u, m, p)...) - end + ) end convert_velocity_type(::StaticVector{S,T}, v) where {S,T} = convert(SVector{S,T}, v) diff --git a/src/tracing/radiative-transfer-problem.jl b/src/tracing/radiative-transfer-problem.jl index abd8d662..76e616df 100644 --- a/src/tracing/radiative-transfer-problem.jl +++ b/src/tracing/radiative-transfer-problem.jl @@ -26,19 +26,30 @@ end absorption_coefficient(m::AbstractMetric, d::AbstractAccretionGeometry, x, ν) = 0.0 emissivity_coefficient(m::AbstractMetric, d::AbstractAccretionGeometry, x, ν) = 0.0 -mutable struct RadiativeTransferIntegrationParameters{V} <: AbstractIntegrationParameters - status::StatusCodes.T +struct RadiativeTransferIntegrationParameters{M,V} <: AbstractIntegrationParameters{M} + metric::M + status::MutStatusCode within_geometry::V + RadiativeTransferIntegrationParameters( + metric::M, + status, + within_geometry::V, + ) where {M,V} = new{M,V}(metric, MutStatusCode(status), within_geometry) end -function _radiative_transfer_integration_parameters(status::StatusCodes.T, geometry) +function _radiative_transfer_integration_parameters( + metric::AbstractMetric, + status::StatusCodes.T, + geometry, +) within_geometry = map(!is_finite_disc, geometry) - RadiativeTransferIntegrationParameters(status, within_geometry) + RadiativeTransferIntegrationParameters(metric, status, within_geometry) end set_status_code!(params::RadiativeTransferIntegrationParameters, status::StatusCodes.T) = - params.status = status -get_status_code(params::RadiativeTransferIntegrationParameters) = params.status + params.status[1] = status +get_status_code(params::RadiativeTransferIntegrationParameters) = params.status[1] +get_metric(params::RadiativeTransferIntegrationParameters) = params.metric function update_integration_parameters!( @@ -127,10 +138,10 @@ function radiative_transfer_ode_problem( function f(u::SVector{9,T}, p, λ) where {T} @inbounds let x = SVector{4,T}(u[1:4]), k = SVector{4,T}(u[5:8]), I = u[9] - - dk = SVector{4,T}(geodesic_equation(m, x, k)) + _m = get_metric(p) + dk = SVector{4,T}(geodesic_equation(_m, x, k)) dI = _intensity_delta( - m, + _m, x, k, geometry, @@ -151,7 +162,7 @@ function radiative_transfer_ode_problem( f, u_init, time_domain, - _radiative_transfer_integration_parameters(StatusCodes.NoStatus, geometry); + _radiative_transfer_integration_parameters(m, StatusCodes.NoStatus, geometry); callback = callback, ) end From dd0b74dba5e7aabd7450b2eb6017c9302be2c950 Mon Sep 17 00:00:00 2001 From: fjebaker Date: Sun, 30 Jul 2023 12:03:38 +0100 Subject: [PATCH 2/2] fix: add DiffEqBase to dependencies --- Project.toml | 1 + 1 file changed, 1 insertion(+) diff --git a/Project.toml b/Project.toml index f70277d8..a06cf5a3 100644 --- a/Project.toml +++ b/Project.toml @@ -6,6 +6,7 @@ version = "0.4.12" [deps] Buckets = "3235f445-51d8-4100-901d-5b23398ac3ab" DataInterpolations = "82cc6244-b520-54b8-b5a6-8a565e85f1d0" +DiffEqBase = "2b5f629d-d688-5b77-993f-72d75c75574e" DiffEqCallbacks = "459566f4-90b8-5000-8ac3-15dfb0a30def" EnumX = "4e289a0a-7415-4d19-859d-a7e5c4648b56" ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"