Skip to content

Commit

Permalink
fix: prefer parameteric type over eltype
Browse files Browse the repository at this point in the history
  • Loading branch information
fjebaker committed Jul 26, 2023
1 parent e532970 commit 33d7456
Show file tree
Hide file tree
Showing 8 changed files with 33 additions and 32 deletions.
12 changes: 6 additions & 6 deletions src/GradusBase/geometry.jl
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
# use this everywhere where we need a dot product so it's quick and easy to
# change the underlying implementation
@fastmath function _fast_dot(x, y)
@fastmath function _fast_dot(x::AbstractVector{T}, y::AbstractVector) where {T}
@assert size(x) == size(y)
res = zero(eltype(x))
res = zero(T)
@inbounds for i in eachindex(x)
res += x[i] * y[i]
end
Expand All @@ -26,15 +26,15 @@ Project vector `v` onto `u` with metric `g`. Optional first argument may be
"""
mproject(g, v, u) = dotproduct(g, v, u) / propernorm(g, u)

function projectbasis(g, basis, v)
s = zero(SVector{4,eltype(g)})
function projectbasis(g::AbstractMatrix{T}, basis, v) where {T}
s = zero(SVector{4,T})
for e in basis
s += mproject(g, v, e) .* e
end
s
end

function gramschmidt(v, basis, g; tol = 4eps(eltype(g)))
function gramschmidt(v, basis, g::AbstractMatrix{T}; tol = 4eps(T)) where {T}
p = projectbasis(g, basis, v)

while sum(p) > tol
Expand Down Expand Up @@ -81,7 +81,7 @@ that is, returns a tuple that corresponds to the ``x^1, x^2, x^3, x^4`` coordina
# ensure there is an initial direction
# else just set it to r
if sum(state) == 1
state = SVector{4,eltype(state)}(1, 0, 0, 1)
state = SVector{4,Bool}(1, 0, 0, 1)
end
# store number of permutations of the space vectors needed to get
# tetrad in the correct order at the end
Expand Down
10 changes: 5 additions & 5 deletions src/metrics/kerr-newman-ad.jl
Original file line number Diff line number Diff line change
Expand Up @@ -25,10 +25,10 @@ using ..MuladdMacro
@SVector [tt, rr, θθ, ϕϕ, tϕ]
end

function electromagnetic_potential(m, rθ)
function electromagnetic_potential(m, rθ::SVector{2,T}) where {T}
(r, θ) =
Σ₀ = Σ(r, m.a, θ)
(r * m.Q / Σ₀) * SVector{4,eltype(rθ)}(1, 0, 0, -m.a * sin(θ)^2)
(r * m.Q / Σ₀) * SVector{4,T}(1, 0, 0, -m.a * sin(θ)^2)
end
end

Expand Down Expand Up @@ -111,12 +111,12 @@ end

function CircularOrbits.Ω(
m::KerrNewmanMetric,
;
::SVector{2,T},
q = 0.0,
μ = 1.0,
contra_rotating = false,
Ω_init = eltype(rθ)(rθ[1] / 100),
)
Ω_init = T(rθ[1] / 100),
) where {T}
g, jacs = Gradus.metric_jacobian(m, rθ)
# only want the derivatives w.r.t. r
∂rg = jacs[:, 1]
Expand Down
3 changes: 2 additions & 1 deletion src/tracing/configuration.jl
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,8 @@ struct TracingConfiguration{
"Trajectories should be `nothing` when solving only a single geodesic problem.",
)
end
_trajectories = eltype(velocity) <: SVector ? length(velocity) : trajectories
_trajectories =
(V <: AbstractVector && eltype(V) <: SVector) ? length(velocity) : trajectories
_ensemble = restrict_ensemble(m, ensemble)
new{
T,
Expand Down
8 changes: 4 additions & 4 deletions src/tracing/method-implementations/auto-diff.jl
Original file line number Diff line number Diff line change
Expand Up @@ -216,11 +216,11 @@ f(rθ), J
```
but non-allocating.
"""
function metric_jacobian(m::AbstractStaticAxisSymmetric, rθ)
function metric_jacobian(m::AbstractStaticAxisSymmetric, rθ::SVector{2,T}) where {T}
f = x -> metric_components(m, x)
T = typeof(ForwardDiff.Tag(f, eltype(rθ)))
ydual = _static_dual_eval(T, f, rθ)
(ForwardDiff.value.(T, ydual), _extract_jacobian(T, ydual, rθ))
TT = typeof(ForwardDiff.Tag(f, T))
ydual = _static_dual_eval(TT, f, rθ)
(ForwardDiff.value.(TT, ydual), _extract_jacobian(TT, ydual, rθ))
end

@inbounds function geodesic_equation(
Expand Down
18 changes: 9 additions & 9 deletions src/tracing/precision-solvers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -34,16 +34,16 @@ function find_offset_for_radius(
μ = 0.0,
solver_opts...,
)
measure = gp -> begin
function _measure(gp::GeodesicPoint{T}) where {T}
r = if gp.status == StatusCodes.IntersectedWithGeometry
gp.x[2] * sin(gp.x[3])
else
zero(eltype(gp.x))
zero(T)
end
rₑ - r
end

velfunc = r -> begin
function _velfunc(r)
α = r * cos(θₒ)
β = r * sin(θₒ)
constrain_all(m, u, map_impact_parameters(m, u, α, β), μ)
Expand All @@ -53,7 +53,7 @@ function find_offset_for_radius(
integ = _init_integrator(
m,
u,
velfunc(0.0),
_velfunc(0.0),
d,
(0.0, max_time);
save_on = false,
Expand All @@ -64,8 +64,8 @@ function find_offset_for_radius(
if r < 0
return -1000 * r
end
gp = _solve_reinit!(integ, vcat(u, velfunc(r)))
measure(gp)
gp = _solve_reinit!(integ, vcat(u, _velfunc(r)))
_measure(gp)
end

# use adaptive Order0 method : https://juliamath.github.io/Roots.jl/dev/reference/#Roots.Order0
Expand All @@ -77,10 +77,10 @@ function find_offset_for_radius(
end


gp0 = _solve_reinit!(integ, vcat(u, velfunc(r0)))
if !isapprox(measure(gp0), 0.0, atol = 1e-4)
gp0 = _solve_reinit!(integ, vcat(u, _velfunc(r0)))
if !isapprox(_measure(gp0), 0.0, atol = 1e-4)
@warn(
"Poor offset radius found for rₑ = $rₑ, θₑ = $θₒ (offset_max = $offset_max, measure = $(measure(gp0)))."
"Poor offset radius found for rₑ = $rₑ, θₑ = $θₒ (offset_max = $offset_max, measure = $(_measure(gp0)))."
)
return NaN, gp0
end
Expand Down
4 changes: 2 additions & 2 deletions src/tracing/utility.jl
Original file line number Diff line number Diff line change
Expand Up @@ -86,8 +86,8 @@ function _map_impact_parameters(m::AbstractMetric, x, α, β)
xfm(local_momentum(x[2], α, β))
end

function faraday_tensor(m::AbstractMetric, x)
ST = SVector{4,eltype(x)}
function faraday_tensor(m::AbstractMetric, x::AbstractVector{T}) where {T}
ST = SVector{4,T}
dA = ForwardDiff.jacobian(t -> electromagnetic_potential(m, t), SVector(x[2], x[3]))
∂A = hcat(zeros(ST), dA, zeros(ST))
g = inv(metric(m, x))
Expand Down
6 changes: 3 additions & 3 deletions src/transfer-functions/cunningham-transfer-functions.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@

function _adjust_extrema!(g)
g[1] = zero(eltype(g))
g[end] = one(eltype(g))
function _adjust_extrema!(g::AbstractArray{T}) where {T}
g[1] = zero(T)
g[end] = one(T)
end

function _make_sorted_with_adjustments!(g1, f1, t1, g2, f2, t2)
Expand Down
4 changes: 2 additions & 2 deletions src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,9 @@ function _make_interpolation(x, y)
DataInterpolations.LinearInterpolation(y, x)
end

@inline function _threaded_map(f, itr)
@inline function _threaded_map(f, itr::T) where {T}
N = length(itr)
items = !(typeof(itr) <: AbstractArray) ? collect(itr) : itr
items = !(T <: AbstractArray) ? collect(itr) : itr
output = Vector{Core.Compiler.return_type(f, Tuple{eltype(items)})}(undef, N)
Threads.@threads for i = 1:N
@inbounds output[i] = f(items[i])
Expand Down

0 comments on commit 33d7456

Please sign in to comment.