Skip to content

Commit 0524b08

Browse files
committed
if all partials AbstractZero don't call frule
1 parent e9c1348 commit 0524b08

File tree

1 file changed

+9
-1
lines changed

1 file changed

+9
-1
lines changed

src/stage1/forward.jl

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -109,14 +109,22 @@ function shuffle_base(r)
109109
end
110110

111111
function (::∂☆internal{1})(args::AbstractTangentBundle{1}...)
112-
r = frule(DiffractorRuleConfig(), map(first_partial, args), map(primal, args)...)
112+
r = _frule(map(first_partial, args), map(primal, args)...)
113113
if r === nothing
114114
return ∂☆recurse{1}()(args...)
115115
else
116116
return shuffle_base(r)
117117
end
118118
end
119119

120+
_frule(partials, primals...) = frule(DiffractorRuleConfig(), partials, primals...)
121+
function _frule(::NTuple{<:Any, AbstractZero}, f, primal_args...)
122+
# frules are linear in partials, so zero maps to zero, no need to evaluate the frule
123+
# If all partials are immutable AbstractZero subtyoes we know we don't have to worry about a mutating frule either
124+
r = f(primal_args...)
125+
return r, zero_tangent(r)
126+
end
127+
120128
function ChainRulesCore.frule_via_ad(::DiffractorRuleConfig, partials, args...)
121129
bundles = map((p,a) -> ExplicitTangentBundle{1}(a, (p,)), partials, args)
122130
result = ∂☆internal{1}()(bundles...)

0 commit comments

Comments
 (0)