Skip to content

Commit

Permalink
Merge pull request #190 from astro-group-bristol/fergus/speedup-trans…
Browse files Browse the repository at this point in the history
…fer-functions

Speedup transfer function calculations
  • Loading branch information
fjebaker authored Jun 20, 2024
2 parents 245cf7c + c0ee55f commit d7750b4
Show file tree
Hide file tree
Showing 11 changed files with 117 additions and 54 deletions.
1 change: 1 addition & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ QuadGK = "1fd47b50-473d-5c70-9696-f719f8f3bcdc"
RecipesBase = "3cdcf5f2-1ef4-517c-9805-6587b60abb01"
Roots = "f2b01f46-fcfa-551c-844a-d8ac1e96c665"
SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462"
SimpleNonlinearSolve = "727e6d20-b764-4bd8-a329-72de5adea6c7"
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
Symbolics = "0c5d862f-8b57-4792-8d23-62f2024744c7"
Tullio = "bc48ee85-29a4-5162-ae0b-a64e1601d4bc"
Expand Down
1 change: 1 addition & 0 deletions src/Gradus.jl
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ using Optim
using DataInterpolations
using VoronoiCells
using Roots
import SimpleNonlinearSolve
using ProgressMeter
using Buckets
using QuadGK
Expand Down
10 changes: 6 additions & 4 deletions src/metrics/kerr-metric.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ module __BoyerLindquistAD
using ..StaticArrays
using ..MuladdMacro

@muladd @fastmath begin
@muladd begin
Σ(r, a, θ) = r^2 + a^2 * cos(θ)^2
Δ(r, R, a) = r^2 + a^2 - R * r

Expand All @@ -15,13 +15,15 @@ using ..MuladdMacro
cosθ2 = (1 - sinθ2)
# slightly faster, especially when considering AD evals
Σ₀ = r^2 + a^2 * cosθ2
Σ₀⁻¹ = 1 / Σ₀
γ = sinθ2 * R * r * a

tt = -(1 - (R * r) / Σ₀)
tt = -(1 - (R * r) * Σ₀⁻¹)
rr = Σ₀ / Δ(r, R, a)
θθ = Σ₀
ϕϕ = sinθ2 * (r^2 + a^2 + (sinθ2 * R * r * a^2) / Σ₀)
ϕϕ = sinθ2 * (r^2 + a^2 + (γ * a) * Σ₀⁻¹)

= (-R * r * a * sinθ2) / Σ₀
= -γ * Σ₀⁻¹
@SVector [tt, rr, θθ, ϕϕ, tϕ]
end
end
Expand Down
2 changes: 1 addition & 1 deletion src/solution-processing.jl
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ struct GeodesicPoint{T,A} <: AbstractGeodesicPoint{T}
end

function Base.show(io::IO, gp::GeodesicPoint)
print(io, "GeodesicPoint($(gp.x...))")
print(io, "GeodesicPoint($(gp.x))")
end

function Base.show(io::IO, ::MIME"text/plain", gp::GeodesicPoint)
Expand Down
66 changes: 46 additions & 20 deletions src/tracing/precision-solvers.jl
Original file line number Diff line number Diff line change
@@ -1,3 +1,14 @@
function _offset_objective(r, (measure, velfunc, solve_geodesic, best)::Tuple)
r = abs(r)
gp = solve_geodesic(velfunc(r))
m = measure(gp)
if abs(m) < best[2]
best[2] = abs(m)
best[1] = r
end
m
end

function _find_offset_for_measure(
measure,
m::AbstractMetric,
Expand All @@ -13,29 +24,43 @@ function _find_offset_for_measure(
μ = 0,
solver_opts...,
)
function _velfunc(r::T) where {T}
function _velfunc(r)
α, β = _rθ_to_αβ(r, θₒ; α₀ = α₀, β₀ = β₀)
# need constrain_all here since `_solve_reinit!` doesn't normalize
constrain_all(m, x, map_impact_parameters(m, x, α, β), μ)
end

# init a reusable integrator
integ =
_init_integrator(m, x, _velfunc(0.0), d, max_time; save_on = false, solver_opts...)
integ = _init_integrator(
m,
x,
_velfunc(0.0),
d,
max_time;
save_on = false,
integrator_verbose = false,
solver_opts...,
)

function f(r)
if r < 0
return -1000 * r
end
gp = _solve_reinit!(integ, vcat(x, _velfunc(r)))
measure(gp)
function _solve_geodesic(v)
_solve_reinit!(integ, vcat(x, v))
end

# use adaptive Order0 method : https://juliamath.github.io/Roots.jl/dev/reference/#Roots.Order0
r0 = Roots.find_zero(f, initial_r, Roots.Order0(); atol = zero_atol)
best = eltype(x)[0.0, 1.0]
r0_candidate, resid = root_solve(
_offset_objective,
initial_r,
(measure, _velfunc, _solve_geodesic, best);
abstol = zero_atol,
)

gp0 = _solve_reinit!(integ, vcat(x, _velfunc(r0)))
r0, gp0
r0 = if best[2] < abs(resid)
best[1]
else
abs(r0_candidate)
end
gp0 = _solve_geodesic(_velfunc(r0))
r0, gp0, measure(gp0)
end

function _find_offset_for_radius(
Expand All @@ -48,20 +73,21 @@ function _find_offset_for_radius(
kwargs...,
)
function _measure(gp::GeodesicPoint{T}) where {T}
r = if gp.status == StatusCodes.IntersectedWithGeometry
_equatorial_project(gp.x)
else
zero(T)
r = _equatorial_project(gp.x)
if gp.status != StatusCodes.IntersectedWithGeometry
r = -r
end
rₑ - r
end
r0, gp0 = _find_offset_for_measure(_measure, m, x, d, θₒ; kwargs...)
r0, gp0, measure0 = _find_offset_for_measure(_measure, m, x, d, θₒ; kwargs...)

if warn && (r0 < 0)
@warn("Root finder found negative radius for rₑ = $rₑ, θₑ = $θₒ")
end
if !isapprox(_measure(gp0), 0.0, atol = 1e-4)
warn && @warn("Poor offset radius found for rₑ = $rₑ, θₑ = $θₒ")
if !isapprox(measure0, 0.0, atol = 1e-4)
warn && @warn(
"Poor offset radius found for rₑ = $rₑ, θₑ = $θₒ : measured $(_measure(gp0)) with r = $(r0)"
)
return NaN, gp0
end
r0, gp0
Expand Down
13 changes: 1 addition & 12 deletions src/tracing/radiative-transfer-problem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -75,18 +75,7 @@ function _intensity_delta(
indexed_geom = (enumerate(geometry.geometry)...,)
sum(indexed_geom) do args
i, geom = args
_intensity_delta(
m,
x,
k,
geom,
within,
I,
ν₀,
r_isco,
λ;
index = i,
)
_intensity_delta(m, x, k, geom, within, I, ν₀, r_isco, λ; index = i)
end
end

Expand Down
12 changes: 7 additions & 5 deletions src/transfer-functions/cunningham-transfer-functions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@ end
struct _TransferFunctionSetup{T}
h::T
θ_offset::T
"Tolerance for root finding"
zero_atol::T
"Unstable radius with respect to when finer tolerances are needed for the Jacobian evaluation"
unstable_radius::T
α₀::T
Expand All @@ -18,6 +20,7 @@ end
function _TransferFunctionSetup(
m::AbstractMetric{T};
θ_offset = T(0.6),
zero_atol = T(1e-7),
N = 80,
N_extrema = 17,
α₀ = 0,
Expand All @@ -28,6 +31,7 @@ function _TransferFunctionSetup(
setup = _TransferFunctionSetup{T}(
h,
θ_offset,
zero_atol,
Gradus.isco(m) + 1,
convert(T, α₀),
convert(T, β₀),
Expand Down Expand Up @@ -165,7 +169,6 @@ function _setup_workhorse_jacobian_with_kwargs(
rₑ;
max_time = 2 * x[2],
offset_max = 0.4rₑ + 10,
zero_atol = 1e-8,
redshift_pf = ConstPointFunctions.redshift(m, x),
jacobian_disc = d,
tracer_kwargs...,
Expand All @@ -190,7 +193,7 @@ function _setup_workhorse_jacobian_with_kwargs(
d,
rₑ,
θ;
zero_atol = zero_atol,
zero_atol = setup.zero_atol,
offset_max = offset_max,
max_time = max_time,
β₀ = setup.β₀,
Expand Down Expand Up @@ -239,7 +242,6 @@ function _rear_workhorse(
d::AbstractThickAccretionDisc,
rₑ;
max_time = 2 * x[2],
zero_atol = 1e-8,
offset_max = 0.4rₑ + 10,
kwargs...,
)
Expand All @@ -252,7 +254,6 @@ function _rear_workhorse(
plane,
rₑ;
max_time = max_time,
zero_atol = zero_atol,
offset_max = offset_max,
jacobian_disc = d,
kwargs...,
Expand All @@ -267,7 +268,7 @@ function _rear_workhorse(
rₑ,
θ;
initial_r = r,
zero_atol = zero_atol,
zero_atol = setup.zero_atol,
offset_max = offset_max,
max_time = max_time,
β₀ = setup.β₀,
Expand Down Expand Up @@ -394,6 +395,7 @@ function _search_extremal!(data::_TransferDataAccumulator, workhorse, offset)
error("i >= lastindex(data): $i >= $(lastindex(data))")
end
i += 1
# avoid poles
if abs(θ) < 1e-4 || abs(abs(θ) - π) < 1e-4
θ += 1e-4
end
Expand Down
42 changes: 42 additions & 0 deletions src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,14 @@ function (interp::NaNLinearInterpolator)(x)
_interpolate(interp, x)
end

"""
_make_interpolation(x, y)
Interpolate `y` over `x`
Utility method to wrap some interpolation library and provide the same interface
for our needs.
"""
function _make_interpolation(x, y)
@assert size(x) == size(y) "size(x) = $(size(x)), size(y) = $(size(y))"
@assert issorted(x) "x must be sorted!"
Expand All @@ -52,6 +60,40 @@ end
end


"""
root_solve(f_objective, initial_value, args)
Wrapper to different root solving backends to make root solve fast and efficient
"""
function root_solve(
f_objective,
initial_value::T,
args;
abstol = 1e-9,
kwargs...,
) where {T<:Union{<:Number,<:SVector{1}}}
# Roots.find_zero(r -> f_objective(r, args), initial_value, Roots.Order0(); atol = abstol)
x0, f = if T <: Number
function _obj_wrapper(x::SVector, p)
@inbounds SVector{1,eltype(x)}(f_objective(x[1], p))
end
SVector{1}(initial_value), _obj_wrapper
else
initial_value, f_objective
end
prob = SimpleNonlinearSolve.NonlinearProblem{false}(f, x0, args)
sol = solve(
prob,
SimpleNonlinearSolve.SimpleBroyden();
abstol = abstol,
reltol = abstol,
maxiters = 500,
kwargs...,
)
sol.u[1], sol.resid[1]
end


@inline function _symmetric_matrix(comps::AbstractVector{T})::SMatrix{4,4,T} where {T}
@SMatrix [
comps[1] 0 0 comps[5]
Expand Down
4 changes: 2 additions & 2 deletions test/integration/test-precision.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,9 @@ u = SVector(0.0, 1000.0, π / 2, 0.0)
target = SVector(10.0, 0.001, 0.0)

α, β, accuracy = Gradus.impact_parameters_for_target(target, m, u)
@test α -0.0023917213583602584 rtol = 1e-3
@test α -0.0017523796800187875 rtol = 1e-3
@test β 10.969606493445841 rtol = 1e-3
@test accuracy 0.009064564426582843 rtol = 1e-3
@test accuracy 0.009091572709970781 rtol = 1e-3

# new target
target = SVector(10.0, deg2rad(40), -π / 4)
Expand Down
18 changes: 9 additions & 9 deletions test/smoke-tests/cunningham-transfer-functions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ function test_ctf1(a, angle, rₑ)
d,
rₑ,
;
chart = Gradus.chart_for_metric(m, 2 * x[2]),
chart = Gradus.chart_for_metric(m, 2 * x[2], closest_approach = 1.005),
N = 80,
)
end
Expand All @@ -21,16 +21,16 @@ function measure_ctf(ctf)
end

# test for different angles
@test measure_ctf(test_ctf1(0.998, 3, 4.0)) 0.12161376109873144 atol = 1e-4
@test measure_ctf(test_ctf1(0.998, 35, 4.0)) 0.10362951554307089 atol = 1e-4
@test measure_ctf(test_ctf1(0.998, 74, 4.0)) 0.054070356518218586 atol = 1e-4
@test measure_ctf(test_ctf1(0.998, 85, 4.0)) 0.034811715215212875 atol = 1e-4
@test measure_ctf(test_ctf1(0.998, 3, 4.0)) 0.12161376109873144 atol = 1e-3
@test measure_ctf(test_ctf1(0.998, 35, 4.0)) 0.10362951554307089 atol = 1e-3
@test measure_ctf(test_ctf1(0.998, 74, 4.0)) 0.054070356518218586 atol = 1e-3
@test measure_ctf(test_ctf1(0.998, 85, 4.0)) 0.034811715215212875 atol = 1e-3

# different radii
@test measure_ctf(test_ctf1(0.998, 30, 4.0)) 0.10779115390995794 atol = 1e-4
@test measure_ctf(test_ctf1(0.998, 30, 7.0)) 0.1202759989850966 atol = 1e-4
@test measure_ctf(test_ctf1(0.998, 30, 10.0)) 0.12461894214061674 atol = 1e-4
@test measure_ctf(test_ctf1(0.998, 30, 15.0)) 0.1275864358885266 atol = 1e-4
@test measure_ctf(test_ctf1(0.998, 30, 4.0)) 0.10779115390995794 atol = 1e-3
@test measure_ctf(test_ctf1(0.998, 30, 7.0)) 0.1202759989850966 atol = 1e-3
@test measure_ctf(test_ctf1(0.998, 30, 10.0)) 0.12461894214061674 atol = 1e-3
@test measure_ctf(test_ctf1(0.998, 30, 15.0)) 0.1275864358885266 atol = 1e-3

# large radii
@test measure_ctf(test_ctf1(0.998, 30, 300.0)) 0.13191798015557799 rtol = 1e-2
Expand Down
2 changes: 1 addition & 1 deletion test/transfer-functions/test-thick-disc.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,4 +8,4 @@ d = ShakuraSunyaev(m)
tf = cunningham_transfer_function(m, x, d, 3.0; β₀ = 1.0)

total = sum(filter(!isnan, tf.f))
@test total 12.249315831376165 atol = 1e-4
@test total 12.245276038643347 atol = 1e-4

0 comments on commit d7750b4

Please sign in to comment.