Skip to content

Commit

Permalink
Enzyme: simplify via mixedduplicated
Browse files Browse the repository at this point in the history
  • Loading branch information
wsmoses committed Jun 21, 2024
1 parent b1d557b commit ecc1329
Show file tree
Hide file tree
Showing 2 changed files with 3 additions and 33 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ UnsafeAtomicsLLVM = "d80eeb9a-aca5-4d75-85e5-170c8b632249"
[compat]
Adapt = "0.4, 1.0, 2.0, 3.0, 4"
Atomix = "0.1"
EnzymeCore = "0.7.1"
EnzymeCore = "0.7.5"
InteractiveUtils = "1.6"
LinearAlgebra = "1.6"
MacroTools = "0.5"
Expand Down
34 changes: 2 additions & 32 deletions ext/EnzymeExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -70,25 +70,6 @@ module EnzymeExt
fwd_kernel(f, args...; ndrange, workgroupsize)
end


@inline function make_active_byref(f::F, ::Val{ActiveTys}) where {F, ActiveTys}
if !any(ActiveTys)
return f
end
function inact(ctx, args2::Vararg{Any, N}) where N
args3 = ntuple(Val(N)) do i
Base.@_inline_meta
if ActiveTys[i]
args2[i][]
else
args2[i]
end
end
f(ctx, args3...)
end
return inact
end

function EnzymeRules.augmented_primal(config::Config, func::Const{<:Kernel{CPU}}, ::Type{Const{Nothing}}, args::Vararg{Any, N}; ndrange=nothing, workgroupsize=nothing) where N
kernel = func.val
f = kernel.f
Expand All @@ -102,11 +83,6 @@ module EnzymeExt
# TODO autodiff_deferred on the func.val
ModifiedBetween = Val((overwritten(config)[1], false, overwritten(config)[2:end]...))

tup = Val(ntuple(Val(N)) do i
Base.@_inline_meta
args[i] isa Active
end)
f = make_active_byref(f, tup)
FT = Const{Core.Typeof(f)}

arg_refs = ntuple(Val(N)) do i
Expand All @@ -120,7 +96,7 @@ module EnzymeExt
args2 = ntuple(Val(N)) do i
Base.@_inline_meta
if args[i] isa Active
Duplicated(Ref(args[i].val), arg_refs[i])
MixedDuplicated(args[i].val, arg_refs[i])
else
args[i]
end
Expand Down Expand Up @@ -150,7 +126,7 @@ module EnzymeExt
args2 = ntuple(Val(N)) do i
Base.@_inline_meta
if args[i] isa Active
Duplicated(Ref(args[i].val), arg_refs[i])
MixedDuplicated(args[i].val, arg_refs[i])
else
args[i]
end
Expand All @@ -159,12 +135,6 @@ module EnzymeExt
kernel = func.val
f = kernel.f

tup = Val(ntuple(Val(N)) do i
Base.@_inline_meta
args[i] isa Active
end)
f = make_active_byref(f, tup)

ModifiedBetween = Val((overwritten(config)[1], false, overwritten(config)[2:end]...))

rev_kernel = similar(func.val, rev_cpu)
Expand Down

0 comments on commit ecc1329

Please sign in to comment.