Skip to content

Commit

Permalink
More performance fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
willtebbutt committed Feb 3, 2025
1 parent 3a667d9 commit 7631484
Showing 1 changed file with 7 additions and 4 deletions.
11 changes: 7 additions & 4 deletions src/tangents.jl
Original file line number Diff line number Diff line change
Expand Up @@ -610,7 +610,9 @@ Add `x` to `y`. If `ismutabletype(T)`, then `increment!!(x, y) === x` must hold.
That is, `increment!!` will mutate `x`.
This must apply recursively if `T` is a composite type whose fields are mutable.
"""
increment!!(x::T, y::T) where {T} = increment_internal!!(IdDict{Any,Bool}(), x, y)
function increment!!(x::T, y::T) where {T}
return increment_internal!!(isbitstype(T) ? NoCache() : IdDict{Any,Bool}(), x, y)
end

increment_internal!!(::IncCache, ::NoTangent, ::NoTangent) = NoTangent()
increment_internal!!(::IncCache, x::T, y::T) where {T<:IEEEFloat} = x + y
Expand All @@ -619,10 +621,11 @@ function increment_internal!!(::IncCache, x::Ptr{T}, y::Ptr{T}) where {T}
end
@generated function increment_internal!!(c::IncCache, x::T, y::T) where {T<:Tuple}
inc_exprs = map(n -> :(increment_internal!!(c, x[$n], y[$n])), 1:fieldcount(T))
return Expr(:(::), Expr(:call, :tuple, inc_exprs...), T)
return Expr(:call, :tuple, inc_exprs...)
end
function increment_internal!!(c::IncCache, x::T, y::T) where {T<:NamedTuple}
return T(tuple_map((x, y) -> increment_internal!!(c, x, y), x, y))
@generated function increment_internal!!(c::IncCache, x::T, y::T) where {T<:NamedTuple}
inc_exprs = map(n -> :(increment_internal!!(c, x[$n], y[$n])), 1:fieldcount(T))
return Expr(:new, T, inc_exprs...)
end
function increment_internal!!(c::IncCache, x::T, y::T) where {T<:PossiblyUninitTangent}
is_init(x) && is_init(y) && return T(increment_internal!!(c, val(x), val(y)))
Expand Down

0 comments on commit 7631484

Please sign in to comment.