Skip to content

Commit

Permalink
More fixups
Browse files Browse the repository at this point in the history
  • Loading branch information
wsmoses committed Sep 25, 2023
1 parent 61310ce commit ee36810
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 10 deletions.
12 changes: 3 additions & 9 deletions src/enzyme.jl
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,7 @@ function EnzymeCore.EnzymeRules.reverse(config, func::EnzymeCore.Const{typeof(NN
end

ddsts = dst.dval
dsrcs = src.dval
dsrcs = (typeof(src) <: EnzymeCore.Const) ? ddsts : src.dval

if EnzymeCore.EnzymeRules.width(config) == 1
ddsts = (ddsts,)
Expand Down Expand Up @@ -182,7 +182,7 @@ function EnzymeCore.EnzymeRules.reverse(config, func::EnzymeCore.Const{typeof(NN
end

ddsts = dst.dval
dsrcs = src.dval
dsrcs = (typeof(src) <: EnzymeCore.Const) ? ddsts : src.dval

if EnzymeCore.EnzymeRules.width(config) == 1
ddsts = (ddsts,)
Expand Down Expand Up @@ -322,12 +322,6 @@ function EnzymeCore.EnzymeRules.augmented_primal(config, func::EnzymeCore.Const{
keep = nothing
end

# Cache idx if its overwritten
cache_idx = ( EnzymeCore.EnzymeRules.overwritten(config)[4]
&& !(typeof(src) <: EnzymeCore.Const)
&& !(typeof(dst) <: EnzymeCore.Const)
) ? copy(idx.val) : nothing

return EnzymeCore.EnzymeRules.AugmentedReturn(primal, shadow, keep)
end

Expand All @@ -336,7 +330,7 @@ function EnzymeCore.EnzymeRules.reverse(config, func::EnzymeCore.Const{typeof(NN
val = convert(T, 1/(1-p.val))

ddsts = dst.dval
dsrcs = src.dval
dsrcs = (typeof(src) <: EnzymeCore.Const) ? ddsts : src.dval

if EnzymeCore.EnzymeRules.width(config) == 1
ddsts = (ddsts,)
Expand Down
2 changes: 1 addition & 1 deletion test/conv.jl
Original file line number Diff line number Diff line change
Expand Up @@ -895,7 +895,7 @@ end

EnzymeTestUtils.are_activities_compatible(Tret, Tdst, Tw, Tw) || continue

EnzymeTestUtils.test_reverse(curconv!, Tret, (dst, Tdst), (x, Tx), (x, Tw), (idx, EnzymeCore.Const))
EnzymeTestUtils.test_reverse(curconv!, Tret, (dst, Tdst), (x, Tx), (x, Tw), (cdims, EnzymeCore.Const))
end
end
end

0 comments on commit ee36810

Please sign in to comment.