Skip to content

Commit c00ceb7

Browse files
authored
Merge pull request #219 from JuliaDiff/ox/mutation2
Forward mode mutable struct support
2 parents 6f1bdf0 + 8badafd commit c00ceb7

17 files changed

+228
-58
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ StructArrays = "09ab397b-f2b6-538f-b94a-2f83cf4a842a"
1818
[compat]
1919
AbstractDifferentiation = "0.5"
2020
ChainRules = "1.44.6"
21-
ChainRulesCore = "1.15.3"
21+
ChainRulesCore = "1.20"
2222
Combinatorics = "1"
2323
Cthulhu = "2"
2424
OffsetArrays = "1"

src/codegen/forward.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ function fwd_transform!(ci, mi, nargs, N)
3535
args = map(stmt.args) do stmt
3636
emit!(mapstmt!(stmt))
3737
end
38-
return Expr(:call, Core._apply_iterate, FwdIterate(ZeroBundle{N}(iterate)), ∂☆new{N}(), emit!(Expr(:call, tuple, args[1])), args[2:end]...)
38+
return Expr(:call, Core._apply_iterate, FwdIterate(DNEBundle{N}(iterate)), ∂☆new{N}(), emit!(Expr(:call, tuple, args[1])), args[2:end]...)
3939
elseif isa(stmt, SSAValue)
4040
return SSAValue(ssa_mapping[stmt.id])
4141
elseif isa(stmt, Core.SlotNumber)
@@ -64,14 +64,14 @@ function fwd_transform!(ci, mi, nargs, N)
6464
# Always disable `@inbounds`, as we don't actually know if the AD'd
6565
# code is truly `@inbounds` or not.
6666
elseif isexpr(stmt, :boundscheck)
67-
return ZeroBundle{N}(true)
67+
return DNEBundle{N}(true)
6868
else
6969
# Fallback case, for literals.
7070
# If it is an Expr, then it is not a literal
7171
if isa(stmt, Expr)
7272
error("Unexprected statement encountered. This is a bug in Diffractor. stmt=$stmt")
7373
end
74-
return Expr(:call, ZeroBundle{N}, stmt)
74+
return Expr(:call, zero_bundle{N}(), stmt)
7575
end
7676
end
7777

src/codegen/forward_demand.jl

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -264,12 +264,12 @@ function forward_diff_no_inf!(ir::IRCode, to_diff::Vector{Pair{SSAValue,Int}};
264264
return transform!(ir, arg, order, maparg)
265265
elseif isa(arg, GlobalRef)
266266
@assert isconst(arg)
267-
return ZeroBundle{order}(getfield(arg.mod, arg.name))
267+
return zero_bundle{order}()(getfield(arg.mod, arg.name))
268268
elseif isa(arg, QuoteNode)
269-
return ZeroBundle{order}(arg.value)
269+
return zero_bundle{order}()(arg.value)
270270
end
271271
@assert !isa(arg, Expr)
272-
return ZeroBundle{order}(arg)
272+
return zero_bundle{order}()(arg)
273273
end
274274

275275
for (ssa, (order, custom)) in enumerate(ssa_orders)
@@ -309,7 +309,7 @@ function forward_diff_no_inf!(ir::IRCode, to_diff::Vector{Pair{SSAValue,Int}};
309309
stmt = insert_node!(ir, ssa, NewInstruction(inst))
310310
end
311311

312-
replace_call!(ir, SSAValue(ssa), Expr(:call, ZeroBundle{order}, stmt))
312+
replace_call!(ir, SSAValue(ssa), Expr(:call, zero_bundle{order}(), stmt))
313313
elseif isa(stmt, SSAValue) || isa(stmt, QuoteNode)
314314
inst[:inst] = maparg(stmt, SSAValue(ssa), order)
315315
inst[:type] = Any
@@ -329,7 +329,7 @@ function forward_diff_no_inf!(ir::IRCode, to_diff::Vector{Pair{SSAValue,Int}};
329329
inst[:type] = Any
330330
inst[:flag] |= CC.IR_FLAG_REFINED
331331
else
332-
val = ZeroBundle{order}(inst[:inst])
332+
val = zero_bundle{order}()(inst[:inst])
333333
inst[:inst] = val
334334
inst[:type] = Const(val)
335335
end
@@ -362,6 +362,6 @@ function forward_diff!(interp::ADInterpreter, ir::IRCode, src::CodeInfo, mi::Met
362362
rt = CC._ir_abstract_constant_propagation(interp, irsv)
363363

364364
ir = compact!(ir)
365-
365+
366366
return ir
367367
end

src/extra_rules.jl

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -172,9 +172,13 @@ function ChainRules.rrule(::DiffractorRuleConfig, ::Type{SArray{S, T, N, L}}, x:
172172
end
173173

174174
function ChainRules.frule((_, ∂x), ::Type{SArray{S, T, N, L}}, x::NTuple{L,T}) where {S, T, N, L}
175-
SArray{S, T, N, L}(x), SArray{S, T, N, L}(∂x.backing)
175+
Δx = SArray{S, T, N, L}(ChainRulesCore.backing(∂x))
176+
SArray{S, T, N, L}(x), Δx
176177
end
177178

179+
Base.view(t::Tangent{T}, inds) where T<:SVector = view(T(ChainRulesCore.backing(t.data)), inds)
180+
Base.getindex(t::Tangent{<:SVector, <:NamedTuple}, ind::Int) = ChainRulesCore.backing(t.data)[ind]
181+
178182
function ChainRules.frule((_, ∂x), ::Type{SArray{S, T, N, L}}, x::NTuple{L,Any}) where {S, T, N, L}
179183
SArray{S, T, N, L}(x), SArray{S}(∂x)
180184
end
@@ -262,3 +266,18 @@ Base.real(z::NoTangent) = z # TODO should be in CRC, https://github.com/JuliaDi
262266

263267
# Avoid https://github.com/JuliaDiff/ChainRulesCore.jl/pull/495
264268
ChainRulesCore._backing_error(P::Type{<:Base.Pairs}, G::Type{<:NamedTuple}, E::Type{<:AbstractDict}) = nothing
269+
270+
# Needed for higher order so we don't see the `backing` field of StructuralTangents, just the contents
271+
# SHould these be in ChainRules/ChainRulesCore?
272+
# is this always the right behavour, or just because of how we do higher order
273+
function ChainRulesCore.frule((_, Δ, _, _), ::typeof(getproperty), strct::StructuralTangent, sym::Union{Int,Symbol}, inbounds)
274+
return (getproperty(strct, sym, inbounds), getproperty(Δ, sym))
275+
end
276+
277+
278+
function ChainRulesCore.frule((_, ȯbj, _, ẋ), ::typeof(setproperty!), obj::MutableTangent, field, x)
279+
ȯbj::MutableTangent
280+
y = setproperty!(obj, field, x)
281+
= setproperty!(ȯbj, field, ẋ)
282+
return y, ẏ
283+
end

src/higher_fwd_rules.jl

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -19,27 +19,27 @@ end
1919

2020
jeval(j, x) = j(x)
2121
for f in (sin, cos, exp)
22-
function (∂☆ₙ::∂☆{N})(fb::ZeroBundle{N, typeof(f)}, x::TaylorBundle{N}) where {N}
22+
function (∂☆ₙ::∂☆{N})(fb::AbstractZeroBundle{N, typeof(f)}, x::TaylorBundle{N}) where {N}
2323
njet(Val{N}(), primal(fb), primal(x))(x)
2424
end
25-
function (∂⃖ₙ::∂⃖{N})(∂☆ₘ::∂☆{M}, fb::ZeroBundle{M, typeof(f)}, x::TaylorBundle{M}) where {N, M}
25+
function (∂⃖ₙ::∂⃖{N})(∂☆ₘ::∂☆{M}, fb::AbstractZeroBundle{M, typeof(f)}, x::TaylorBundle{M}) where {N, M}
2626
∂⃖ₙ(jeval, njet(Val{N+M}(), primal(fb), primal(x)), x)
2727
end
2828
end
2929

3030
# TODO: It's a bit embarassing that we need to write these out, but currently the
3131
# compiler is not strong enough to automatically lift the frule. Let's hope we
3232
# can delete these in the near future.
33-
function (∂☆ₙ::∂☆{N})(fb::ZeroBundle{N, typeof(+)}, a::TaylorBundle{N}, b::TaylorBundle{N}) where {N}
33+
function (∂☆ₙ::∂☆{N})(fb::AbstractZeroBundle{N, typeof(+)}, a::TaylorBundle{N}, b::TaylorBundle{N}) where {N}
3434
TaylorBundle{N}(primal(a) + primal(b),
3535
map(+, a.tangent.coeffs, b.tangent.coeffs))
3636
end
3737

38-
function (∂☆ₙ::∂☆{N})(fb::ZeroBundle{N, typeof(+)}, a::TaylorBundle{N}, b::ZeroBundle{N}) where {N}
38+
function (∂☆ₙ::∂☆{N})(fb::AbstractZeroBundle{N, typeof(+)}, a::TaylorBundle{N}, b::AbstractZeroBundle{N}) where {N}
3939
TaylorBundle{N}(primal(a) + primal(b), a.tangent.coeffs)
4040
end
4141

42-
function (∂☆ₙ::∂☆{N})(fb::ZeroBundle{N, typeof(-)}, a::TaylorBundle{N}, b::TaylorBundle{N}) where {N}
42+
function (∂☆ₙ::∂☆{N})(fb::AbstractZeroBundle{N, typeof(-)}, a::TaylorBundle{N}, b::TaylorBundle{N}) where {N}
4343
TaylorBundle{N}(primal(a) - primal(b),
4444
map(-, a.tangent.coeffs, b.tangent.coeffs))
4545
end

src/stage1/broadcast.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ end
99

1010
n_getfield(∂ₙ::∂☆{N}, b::ATB{N}, x::Union{Symbol, Int}) where {N} = ∂ₙ(ZeroBundle{N}(getfield), b, ZeroBundle{N}(x))
1111

12-
function (∂ₙ::∂☆{N})(zc::ZeroBundle{N, typeof(copy)},
12+
function (∂ₙ::∂☆{N})(zc::AbstractZeroBundle{N, typeof(copy)},
1313
bc::ATB{N, <:Broadcasted}) where {N}
1414
bc = ∂ₙ(ZeroBundle{N}(Broadcast.flatten), bc)
1515
args = n_getfield(∂ₙ, bc, :args)

src/stage1/forward.jl

Lines changed: 29 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -98,22 +98,33 @@ struct ∂☆shuffle{N}; end
9898

9999
function shuffle_base(r)
100100
(primal, dual) = r
101-
if isa(dual, Union{NoTangent, ZeroTangent})
101+
if dual isa NoTangent
102102
UniformBundle{1}(primal, dual)
103103
else
104+
if dual isa ZeroTangent # Normalize zero for type-stability reasons
105+
dual = zero_tangent(primal)
106+
end
104107
TaylorBundle{1}(primal, (dual,))
105108
end
106109
end
107110

108111
function (::∂☆internal{1})(args::AbstractTangentBundle{1}...)
109-
r = frule(DiffractorRuleConfig(), map(first_partial, args), map(primal, args)...)
112+
r = _frule(map(first_partial, args), map(primal, args)...)
110113
if r === nothing
111114
return ∂☆recurse{1}()(args...)
112115
else
113116
return shuffle_base(r)
114117
end
115118
end
116119

120+
_frule(partials, primals...) = frule(DiffractorRuleConfig(), partials, primals...)
121+
function _frule(::NTuple{<:Any, AbstractZero}, f, primal_args...)
122+
# frules are linear in partials, so zero maps to zero, no need to evaluate the frule
123+
# If all partials are immutable AbstractZero subtyoes we know we don't have to worry about a mutating frule either
124+
r = f(primal_args...)
125+
return r, zero_tangent(r)
126+
end
127+
117128
function ChainRulesCore.frule_via_ad(::DiffractorRuleConfig, partials, args...)
118129
bundles = map((p,a) -> ExplicitTangentBundle{1}(a, (p,)), partials, args)
119130
result = ∂☆internal{1}()(bundles...)
@@ -131,12 +142,12 @@ end
131142
function (::∂☆internal{N})(f::AbstractZeroBundle{N}, args::AbstractZeroBundle{N}...) where {N}
132143
f_v = primal(f)
133144
args_v = map(primal, args)
134-
return ZeroBundle{N}(f_v(args_v...))
145+
return zero_bundle{N}()(f_v(args_v...))
135146
end
136147
function (::∂☆internal{1})(f::AbstractZeroBundle{1}, args::AbstractZeroBundle{1}...)
137148
f_v = primal(f)
138149
args_v = map(primal, args)
139-
return ZeroBundle{1}(f_v(args_v...))
150+
return zero_bundle{1}()(f_v(args_v...))
140151
end
141152

142153
function (::∂☆internal{N})(args::AbstractTangentBundle{N}...) where {N}
@@ -193,25 +204,25 @@ struct FwdMap{N, T<:AbstractTangentBundle{N}}
193204
end
194205
(f::FwdMap{N})(args::AbstractTangentBundle{N}...) where {N} = ∂☆{N}()(f.f, args...)
195206

196-
function (::∂☆{N})(::ZeroBundle{N, typeof(map)}, f::ATB{N}, tup::TaylorBundle{N, <:Tuple}) where {N}
207+
function (::∂☆{N})(::AbstractZeroBundle{N, typeof(map)}, f::ATB{N}, tup::TaylorBundle{N, <:Tuple}) where {N}
197208
∂vararg{N}()(map(FwdMap(f), destructure(tup))...)
198209
end
199210

200-
function (::∂☆{N})(::ZeroBundle{N, typeof(map)}, f::ATB{N}, args::ATB{N, <:AbstractArray}...) where {N}
211+
function (::∂☆{N})(::AbstractZeroBundle{N, typeof(map)}, f::ATB{N}, args::ATB{N, <:AbstractArray}...) where {N}
201212
# TODO: This could do an inplace map! to avoid the extra rebundling
202213
rebundle(map(FwdMap(f), map(unbundle, args)...))
203214
end
204215

205-
function (::∂☆{N})(::ZeroBundle{N, typeof(map)}, f::ATB{N}, args::ATB{N}...) where {N}
216+
function (::∂☆{N})(::AbstractZeroBundle{N, typeof(map)}, f::ATB{N}, args::ATB{N}...) where {N}
206217
∂☆recurse{N}()(ZeroBundle{N, typeof(map)}(map), f, args...)
207218
end
208219

209220

210-
function (::∂☆{N})(f::ZeroBundle{N, typeof(ifelse)}, arg::ATB{N, Bool}, args::ATB{N}...) where {N}
221+
function (::∂☆{N})(f::AbstractZeroBundle{N, typeof(ifelse)}, arg::ATB{N, Bool}, args::ATB{N}...) where {N}
211222
ifelse(arg.primal, args...)
212223
end
213224

214-
function (::∂☆{N})(f::ZeroBundle{N, typeof(Core.ifelse)}, arg::ATB{N, Bool}, args::ATB{N}...) where {N}
225+
function (::∂☆{N})(f::AbstractZeroBundle{N, typeof(Core.ifelse)}, arg::ATB{N, Bool}, args::ATB{N}...) where {N}
215226
Core.ifelse(arg.primal, args...)
216227
end
217228

@@ -233,48 +244,48 @@ end
233244
primal(∂☆{N}()(ZeroBundle{N}(getindex), r, ZeroBundle{N}(2))))
234245
end
235246

236-
function (this::∂☆{N})(::ZeroBundle{N, typeof(Core._apply_iterate)}, iterate::ATB{N}, f::ATB{N}, args::ATB{N}...) where {N}
247+
function (this::∂☆{N})(::AbstractZeroBundle{N, typeof(Core._apply_iterate)}, iterate::ATB{N}, f::ATB{N}, args::ATB{N}...) where {N}
237248
Core._apply_iterate(FwdIterate(iterate), this, (f,), args...)
238249
end
239250

240251

241-
function (this::∂☆{N})(::ZeroBundle{N, typeof(iterate)}, t::TaylorBundle{N, <:Tuple}) where {N}
252+
function (this::∂☆{N})(::AbstractZeroBundle{N, typeof(iterate)}, t::TaylorBundle{N, <:Tuple}) where {N}
242253
r = iterate(destructure(t))
243254
r === nothing && return ZeroBundle{N}(nothing)
244255
∂vararg{N}()(r[1], ZeroBundle{N}(r[2]))
245256
end
246257

247-
function (this::∂☆{N})(::ZeroBundle{N, typeof(iterate)}, t::TaylorBundle{N, <:Tuple}, a::ATB{N}, args::ATB{N}...) where {N}
258+
function (this::∂☆{N})(::AbstractZeroBundle{N, typeof(iterate)}, t::TaylorBundle{N, <:Tuple}, a::ATB{N}, args::ATB{N}...) where {N}
248259
r = iterate(destructure(t), primal(a), map(primal, args)...)
249260
r === nothing && return ZeroBundle{N}(nothing)
250261
∂vararg{N}()(r[1], ZeroBundle{N}(r[2]))
251262
end
252263

253-
function (this::∂☆{N})(::ZeroBundle{N, typeof(Base.indexed_iterate)}, t::TaylorBundle{N, <:Tuple}, i::ATB{N}) where {N}
264+
function (this::∂☆{N})(::AbstractZeroBundle{N, typeof(Base.indexed_iterate)}, t::TaylorBundle{N, <:Tuple}, i::ATB{N}) where {N}
254265
r = Base.indexed_iterate(destructure(t), primal(i))
255266
∂vararg{N}()(r[1], ZeroBundle{N}(r[2]))
256267
end
257268

258-
function (this::∂☆{N})(::ZeroBundle{N, typeof(Base.indexed_iterate)}, t::TaylorBundle{N, <:Tuple}, i::ATB{N}, st1::ATB{N}, st::ATB{N}...) where {N}
269+
function (this::∂☆{N})(::AbstractZeroBundle{N, typeof(Base.indexed_iterate)}, t::TaylorBundle{N, <:Tuple}, i::ATB{N}, st1::ATB{N}, st::ATB{N}...) where {N}
259270
r = Base.indexed_iterate(destructure(t), primal(i), primal(st1), map(primal, st)...)
260271
∂vararg{N}()(r[1], ZeroBundle{N}(r[2]))
261272
end
262273

263-
function (this::∂☆{N})(::ZeroBundle{N, typeof(Base.indexed_iterate)}, t::TangentBundle{N, <:Tuple}, i::ATB{N}, st::ATB{N}...) where {N}
274+
function (this::∂☆{N})(::AbstractZeroBundle{N, typeof(Base.indexed_iterate)}, t::TangentBundle{N, <:Tuple}, i::ATB{N}, st::ATB{N}...) where {N}
264275
∂vararg{N}()(this(ZeroBundle{N}(getfield), t, i), ZeroBundle{N}(primal(i) + 1))
265276
end
266277

267-
function (this::∂☆{N})(::ZeroBundle{N, typeof(getindex)}, t::TaylorBundle{N, <:Tuple}, i::ZeroBundle) where {N}
278+
function (this::∂☆{N})(::AbstractZeroBundle{N, typeof(getindex)}, t::TaylorBundle{N, <:Tuple}, i::AbstractZeroBundle) where {N}
268279
field_ind = primal(i)
269280
the_partials = ntuple(order_ind->partial(t, order_ind)[field_ind], N)
270281
TaylorBundle{N}(primal(t)[field_ind], the_partials)
271282
end
272283

273-
function (this::∂☆{N})(::ZeroBundle{N, typeof(typeof)}, x::ATB{N}) where {N}
284+
function (this::∂☆{N})(::AbstractZeroBundle{N, typeof(typeof)}, x::ATB{N}) where {N}
274285
DNEBundle{N}(typeof(primal(x)))
275286
end
276287

277-
function (this::∂☆{N})(f::ZeroBundle{N, Core.IntrinsicFunction}, args::ATB{N}...) where {N}
288+
function (this::∂☆{N})(f::AbstractZeroBundle{N, Core.IntrinsicFunction}, args::ATB{N}...) where {N}
278289
ff = primal(f)
279290
if ff === Base.not_int
280291
DNEBundle{N}(ff(map(primal, args)...))

src/stage1/generated.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -390,7 +390,7 @@ end
390390
lifted_getfield(x::ZeroTangent, s) = ZeroTangent()
391391
lifted_getfield(x::NoTangent, s) = NoTangent()
392392

393-
lifted_getfield(x::Tangent, s) = getproperty(x, s)
393+
lifted_getfield(x::StructuralTangent, s) = getproperty(x, s)
394394

395395
function lifted_getfield(x::Tangent{<:Tangent{T}}, s) where T
396396
bb = getfield(x.backing, 1)

src/stage1/mixed.jl

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -70,26 +70,26 @@ function (f::FwdIterate)(arg::ATB{N}, st) where {N}
7070
primal(∂☆{N}()(ZeroBundle{N}(getindex), r, ZeroBundle{N}(2))))
7171
end
7272
73-
function (this::∂☆{N})(::ZeroBundle{N, typeof(Core._apply_iterate)}, iterate::ATB{N}, f::ATB{N}, args::ATB{N}...) where {N}
73+
function (this::∂☆{N})(::AbstractZeroBundle{N, typeof(Core._apply_iterate)}, iterate::ATB{N}, f::ATB{N}, args::ATB{N}...) where {N}
7474
Core._apply_iterate(FwdIterate(iterate), this, (f,), args...)
7575
end
7676
=#
7777

78-
function (this::∂⃖{N})(that::∂☆{M}, ::ZeroBundle{M, typeof(Core._apply_iterate)},
78+
function (this::∂⃖{N})(that::∂☆{M}, ::AbstractZeroBundle{M, typeof(Core._apply_iterate)},
7979
iterate, f, args::ATB{M, <:Tuple}...) where {N, M}
8080
@assert primal(iterate) === Base.iterate
8181
x, ∂⃖f = Core._apply_iterate(FwdIterate(iterate), this, (that, f), args...)
8282
return x, ApplyOdd{1, c_order(N)}(UnApply{map(x->length(primal(x)), args)}(), ∂⃖f)
8383
end
8484

8585

86-
function ChainRules.rrule(∂::∂☆{N}, m::ZeroBundle{N, typeof(map)}, p::ZeroBundle{N, typeof(+)}, A::ATB{N}, B::ATB{N}) where {N}
86+
function ChainRules.rrule(∂::∂☆{N}, m::AbstractZeroBundle{N, typeof(map)}, p::AbstractZeroBundle{N, typeof(+)}, A::ATB{N}, B::ATB{N}) where {N}
8787
(m, p, A, B), Δ->(NoTangent(), NoTangent(), NoTangent(), Δ, Δ)
8888
end
8989

9090
mapev_unbundled(_, js, a) = rebundle(mapev(js, unbundle(a)))
91-
function (∂⃖ₙ::∂⃖{N})(∂☆ₘ::∂☆{M}, ::ZeroBundle{M, typeof(map)},
92-
f::ZeroBundle{M}, a::ATB{M, <:Array}) where {N, M}
91+
function (∂⃖ₙ::∂⃖{N})(∂☆ₘ::∂☆{M}, ::AbstractZeroBundle{M, typeof(map)},
92+
f::AbstractZeroBundle{M}, a::ATB{M, <:Array}) where {N, M}
9393
@assert Base.issingletontype(typeof(primal(f)))
9494
js = map(primal(a)) do x
9595
∂f = ∂☆{N+M}()(ZeroBundle{N+M}(primal(f)),

src/stage1/recurse_fwd.jl

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -15,31 +15,31 @@ struct ∂☆new{N}; end
1515
function (::∂☆new{1})(B::Type, xs::AbstractTangentBundle{1}...)
1616
primal_args = map(primal, xs)
1717
the_primal = _construct(B, primal_args)
18-
1918
tangent_tup = map(first_partial, xs)
2019
the_partial = if B<:Tuple
2120
Tangent{B, typeof(tangent_tup)}(tangent_tup)
2221
else
2322
names = fieldnames(B)
2423
tangent_nt = NamedTuple{names}(tangent_tup)
25-
Tangent{B, typeof(tangent_nt)}(tangent_nt)
24+
StructuralTangent{B}(tangent_nt)
2625
end
27-
return TaylorBundle{1, B}(the_primal, (the_partial,))
26+
B2 = typeof(the_primal) # HACK: if the_primal actually has types in it then we want to make sure we get DataType not Type(...)
27+
return TaylorBundle{1, B2}(the_primal, (the_partial,))
2828
end
2929

3030
function (::∂☆new{N})(B::Type, xs::AbstractTangentBundle{N}...) where {N}
3131
primal_args = map(primal, xs)
3232
the_primal = _construct(B, primal_args)
33-
3433
the_partials = ntuple(Val{N}()) do ii
35-
iith_order_type = ii==1 ? B : Any # the type of the higher order tangents isn't worth tracking
3634
tangent_tup = map(x->partial(x, ii), xs)
3735
tangent = if B<:Tuple
38-
Tangent{iith_order_type, typeof(tangent_tup)}(tangent_tup)
36+
Tangent{B, typeof(tangent_tup)}(tangent_tup)
3937
else
38+
# No matter the order we use `StructuralTangent{B}` for the partial
39+
# It follows all required properties of the tangent to the n-1th order tangent
4040
names = fieldnames(B)
4141
tangent_nt = NamedTuple{names}(tangent_tup)
42-
Tangent{iith_order_type, typeof(tangent_nt)}(tangent_nt)
42+
StructuralTangent{B}(tangent_nt)
4343
end
4444
return tangent
4545
end
@@ -50,7 +50,7 @@ _construct(::Type{B}, args) where B<:Tuple = B(args)
5050
# Hack for making things that do not have public constructors constructable:
5151
@generated _construct(B::Type, args) = Expr(:splatnew, :B, :args)
5252

53-
@generated (::∂☆new{N})(B::Type) where {N} = return :(ZeroBundle{$N}($(Expr(:new, :B))))
53+
@generated (::∂☆new{N})(B::Type) where {N} = return :(zero_bundle{$N}()($(Expr(:new, :B))))
5454

5555
# Sometimes we don't know whether or not we need to the ZeroBundle when doing
5656
# the transform, so this can happen - allow it for now.

0 commit comments

Comments
 (0)