Skip to content

Commit

Permalink
add more types supported with Bundle helper
Browse files Browse the repository at this point in the history
  • Loading branch information
oxinabox committed Aug 21, 2023
1 parent 8dd45c0 commit b60acd4
Show file tree
Hide file tree
Showing 2 changed files with 41 additions and 5 deletions.
25 changes: 21 additions & 4 deletions src/AbstractDifferentiation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,27 @@ import AbstractDifferentiation as AD
struct DiffractorForwardBackend <: AD.AbstractForwardMode
end

bundle(x::Number, dx) = TaylorBundle{1}(x, (dx,))
bundle(x::Tuple, dx) = CompositeBundle{1}(x, dx)
bundle(x::AbstractArray{<:Number}, dx::AbstractArray{<:Number}) = TaylorBundle{1}(x, (dx,)) # TODO check me
# TODO: other types of primal
"""
bundle(primal, tangent)
Wraps a primal up with a tangent into the appropriate kind of `AbstractBundle{1}`.
This is more or less the Diffractor equivelent of ForwardDiff.jl's `Dual` type.
"""
function bundle end
bundle(x, dx::ChainRulesCore.AbstractZero) = UniformBundle{1, typeof(x), typeof(dx)}(x, dx)
bundle(x::Number, dx::Number) = TaylorBundle{1}(x, (dx,))
bundle(x::AbstractArray{<:Number}, dx) = TaylorBundle{1}(x, (dx,))
bundle(x::AbstractArray, dx) = error("Nonnumeric arrays not implemented, that type is a mess")

Check warning on line 15 in src/AbstractDifferentiation.jl

View check run for this annotation

Codecov / codecov/patch

src/AbstractDifferentiation.jl#L15

Added line #L15 was not covered by tests
bundle(x::P, dx::Tangent{P}) where P = _bundle(x, ChainRulesCore.canonicalize(dx))

"helper that assumes tangent is in canonical form"
function _bundle(x::P, dx::Tangent{P}) where P
# SoA to AoS flip (hate this, hate it even more cos we just undo it later when we hit chainrules)
the_bundle = ntuple(Val{fieldcount(P)}()) do ii
bundle(getfield(x, ii), getproperty(dx, ii))
end
return CompositeBundle{1, P}(the_bundle)
end


AD.@primitive function pushforward_function(b::DiffractorForwardBackend, f, args...)
Expand Down
21 changes: 20 additions & 1 deletion test/AbstractDifferentiationTests.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,25 @@
using AbstractDifferentiation, Diffractor, Test, LinearAlgebra
using AbstractDifferentiation, Diffractor, Test, LinearAlgebra, ChainRulesCore
import AbstractDifferentiation as AD
backend = Diffractor.DiffractorForwardBackend()

@testset "bundle" begin
bundle = Diffractor.bundle

@test bundle(1.0, 2.0) isa Diffractor.TaylorBundle{1}
@test bundle([1.0, 2.0], [2.0, 3.0]) isa Diffractor.TaylorBundle{1}
@test bundle(1.5=>2.5, Tangent{Pair{Float64, Float64}}(first=1.0, second=2.0)) isa Diffractor.CompositeBundle{1}
@test bundle(1.1, ChainRulesCore.ZeroTangent()) isa Diffractor.ZeroBundle{1}
@test bundle(1.5=>2.5=>3.5, Tangent{Pair{Float64, Pair{Float64, Float64}}}(first=1.0, second=Tangent{Pair{Float64, Float64}}(first=1.0, second=2.0))) isa Diffractor.CompositeBundle{1}

# noncanonical structural tangent
b = bundle(1.5=>2.5=>3.5, Tangent{Pair{Float64, Pair{Float64, Float64}}}(second=Tangent{Pair{Float64, Float64}}(second=2.0, first=1.0)))
t = Diffractor.first_partial(b)
@test b isa Diffractor.CompositeBundle{1}
@test iszero(t.first)
@test t.second.first == 1.0
@test t.second.second == 2.0
end

@testset "basics" begin
@test AD.derivative(backend, +, 1.5, 10.0) == (1.0, 1.0)
@test AD.derivative(backend, *, 1.5, 10.0) == (10.0, 1.5)
Expand Down Expand Up @@ -50,3 +68,4 @@ include(joinpath(pathof(AbstractDifferentiation), "..", "..", "test", "test_util
end
end
end

0 comments on commit b60acd4

Please sign in to comment.