From 33d7456af0241cc040bc79e66683bd3d20bbe7c4 Mon Sep 17 00:00:00 2001 From: fjebaker Date: Wed, 26 Jul 2023 11:13:49 +0100 Subject: [PATCH] fix: prefer parameteric type over eltype --- src/GradusBase/geometry.jl | 12 ++++++------ src/metrics/kerr-newman-ad.jl | 10 +++++----- src/tracing/configuration.jl | 3 ++- .../method-implementations/auto-diff.jl | 8 ++++---- src/tracing/precision-solvers.jl | 18 +++++++++--------- src/tracing/utility.jl | 4 ++-- .../cunningham-transfer-functions.jl | 6 +++--- src/utils.jl | 4 ++-- 8 files changed, 33 insertions(+), 32 deletions(-) diff --git a/src/GradusBase/geometry.jl b/src/GradusBase/geometry.jl index 50f11e12..e4f0d399 100644 --- a/src/GradusBase/geometry.jl +++ b/src/GradusBase/geometry.jl @@ -1,8 +1,8 @@ # use this everywhere where we need a dot product so it's quick and easy to # change the underlying implementation -@fastmath function _fast_dot(x, y) +@fastmath function _fast_dot(x::AbstractVector{T}, y::AbstractVector) where {T} @assert size(x) == size(y) - res = zero(eltype(x)) + res = zero(T) @inbounds for i in eachindex(x) res += x[i] * y[i] end @@ -26,15 +26,15 @@ Project vector `v` onto `u` with metric `g`. Optional first argument may be """ mproject(g, v, u) = dotproduct(g, v, u) / propernorm(g, u) -function projectbasis(g, basis, v) - s = zero(SVector{4,eltype(g)}) +function projectbasis(g::AbstractMatrix{T}, basis, v) where {T} + s = zero(SVector{4,T}) for e in basis s += mproject(g, v, e) .* e end s end -function gramschmidt(v, basis, g; tol = 4eps(eltype(g))) +function gramschmidt(v, basis, g::AbstractMatrix{T}; tol = 4eps(T)) where {T} p = projectbasis(g, basis, v) while sum(p) > tol @@ -81,7 +81,7 @@ that is, returns a tuple that corresponds to the ``x^1, x^2, x^3, x^4`` coordina # ensure there is an initial direction # else just set it to r if sum(state) == 1 - state = SVector{4,eltype(state)}(1, 0, 0, 1) + state = SVector{4,Bool}(1, 0, 0, 1) end # store number of permutations of the space vectors needed to get # tetrad in the correct order at the end diff --git a/src/metrics/kerr-newman-ad.jl b/src/metrics/kerr-newman-ad.jl index 60a8a2ff..25200731 100644 --- a/src/metrics/kerr-newman-ad.jl +++ b/src/metrics/kerr-newman-ad.jl @@ -25,10 +25,10 @@ using ..MuladdMacro @SVector [tt, rr, θθ, ϕϕ, tϕ] end - function electromagnetic_potential(m, rθ) + function electromagnetic_potential(m, rθ::SVector{2,T}) where {T} (r, θ) = rθ Σ₀ = Σ(r, m.a, θ) - (r * m.Q / Σ₀) * SVector{4,eltype(rθ)}(1, 0, 0, -m.a * sin(θ)^2) + (r * m.Q / Σ₀) * SVector{4,T}(1, 0, 0, -m.a * sin(θ)^2) end end @@ -111,12 +111,12 @@ end function CircularOrbits.Ω( m::KerrNewmanMetric, - rθ; + rθ::SVector{2,T}, q = 0.0, μ = 1.0, contra_rotating = false, - Ω_init = eltype(rθ)(rθ[1] / 100), -) + Ω_init = T(rθ[1] / 100), +) where {T} g, jacs = Gradus.metric_jacobian(m, rθ) # only want the derivatives w.r.t. r ∂rg = jacs[:, 1] diff --git a/src/tracing/configuration.jl b/src/tracing/configuration.jl index 84f6490e..1ad97556 100644 --- a/src/tracing/configuration.jl +++ b/src/tracing/configuration.jl @@ -48,7 +48,8 @@ struct TracingConfiguration{ "Trajectories should be `nothing` when solving only a single geodesic problem.", ) end - _trajectories = eltype(velocity) <: SVector ? length(velocity) : trajectories + _trajectories = + (V <: AbstractVector && eltype(V) <: SVector) ? length(velocity) : trajectories _ensemble = restrict_ensemble(m, ensemble) new{ T, diff --git a/src/tracing/method-implementations/auto-diff.jl b/src/tracing/method-implementations/auto-diff.jl index 618666d3..9a8979f1 100644 --- a/src/tracing/method-implementations/auto-diff.jl +++ b/src/tracing/method-implementations/auto-diff.jl @@ -216,11 +216,11 @@ f(rθ), J ``` but non-allocating. """ -function metric_jacobian(m::AbstractStaticAxisSymmetric, rθ) +function metric_jacobian(m::AbstractStaticAxisSymmetric, rθ::SVector{2,T}) where {T} f = x -> metric_components(m, x) - T = typeof(ForwardDiff.Tag(f, eltype(rθ))) - ydual = _static_dual_eval(T, f, rθ) - (ForwardDiff.value.(T, ydual), _extract_jacobian(T, ydual, rθ)) + TT = typeof(ForwardDiff.Tag(f, T)) + ydual = _static_dual_eval(TT, f, rθ) + (ForwardDiff.value.(TT, ydual), _extract_jacobian(TT, ydual, rθ)) end @inbounds function geodesic_equation( diff --git a/src/tracing/precision-solvers.jl b/src/tracing/precision-solvers.jl index f9db67fa..0f065a93 100644 --- a/src/tracing/precision-solvers.jl +++ b/src/tracing/precision-solvers.jl @@ -34,16 +34,16 @@ function find_offset_for_radius( μ = 0.0, solver_opts..., ) - measure = gp -> begin + function _measure(gp::GeodesicPoint{T}) where {T} r = if gp.status == StatusCodes.IntersectedWithGeometry gp.x[2] * sin(gp.x[3]) else - zero(eltype(gp.x)) + zero(T) end rₑ - r end - velfunc = r -> begin + function _velfunc(r) α = r * cos(θₒ) β = r * sin(θₒ) constrain_all(m, u, map_impact_parameters(m, u, α, β), μ) @@ -53,7 +53,7 @@ function find_offset_for_radius( integ = _init_integrator( m, u, - velfunc(0.0), + _velfunc(0.0), d, (0.0, max_time); save_on = false, @@ -64,8 +64,8 @@ function find_offset_for_radius( if r < 0 return -1000 * r end - gp = _solve_reinit!(integ, vcat(u, velfunc(r))) - measure(gp) + gp = _solve_reinit!(integ, vcat(u, _velfunc(r))) + _measure(gp) end # use adaptive Order0 method : https://juliamath.github.io/Roots.jl/dev/reference/#Roots.Order0 @@ -77,10 +77,10 @@ function find_offset_for_radius( end - gp0 = _solve_reinit!(integ, vcat(u, velfunc(r0))) - if !isapprox(measure(gp0), 0.0, atol = 1e-4) + gp0 = _solve_reinit!(integ, vcat(u, _velfunc(r0))) + if !isapprox(_measure(gp0), 0.0, atol = 1e-4) @warn( - "Poor offset radius found for rₑ = $rₑ, θₑ = $θₒ (offset_max = $offset_max, measure = $(measure(gp0)))." + "Poor offset radius found for rₑ = $rₑ, θₑ = $θₒ (offset_max = $offset_max, measure = $(_measure(gp0)))." ) return NaN, gp0 end diff --git a/src/tracing/utility.jl b/src/tracing/utility.jl index 7786f12b..235d169a 100644 --- a/src/tracing/utility.jl +++ b/src/tracing/utility.jl @@ -86,8 +86,8 @@ function _map_impact_parameters(m::AbstractMetric, x, α, β) xfm(local_momentum(x[2], α, β)) end -function faraday_tensor(m::AbstractMetric, x) - ST = SVector{4,eltype(x)} +function faraday_tensor(m::AbstractMetric, x::AbstractVector{T}) where {T} + ST = SVector{4,T} dA = ForwardDiff.jacobian(t -> electromagnetic_potential(m, t), SVector(x[2], x[3])) ∂A = hcat(zeros(ST), dA, zeros(ST)) g = inv(metric(m, x)) diff --git a/src/transfer-functions/cunningham-transfer-functions.jl b/src/transfer-functions/cunningham-transfer-functions.jl index 43099e20..82d0f4ff 100644 --- a/src/transfer-functions/cunningham-transfer-functions.jl +++ b/src/transfer-functions/cunningham-transfer-functions.jl @@ -1,7 +1,7 @@ -function _adjust_extrema!(g) - g[1] = zero(eltype(g)) - g[end] = one(eltype(g)) +function _adjust_extrema!(g::AbstractArray{T}) where {T} + g[1] = zero(T) + g[end] = one(T) end function _make_sorted_with_adjustments!(g1, f1, t1, g2, f2, t2) diff --git a/src/utils.jl b/src/utils.jl index 0edff63c..0bc5ba9c 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -2,9 +2,9 @@ function _make_interpolation(x, y) DataInterpolations.LinearInterpolation(y, x) end -@inline function _threaded_map(f, itr) +@inline function _threaded_map(f, itr::T) where {T} N = length(itr) - items = !(typeof(itr) <: AbstractArray) ? collect(itr) : itr + items = !(T <: AbstractArray) ? collect(itr) : itr output = Vector{Core.Compiler.return_type(f, Tuple{eltype(items)})}(undef, N) Threads.@threads for i = 1:N @inbounds output[i] = f(items[i])