diff --git a/Project.toml b/Project.toml index 048f08d1..2de3a8aa 100644 --- a/Project.toml +++ b/Project.toml @@ -21,7 +21,7 @@ UnsafeAtomicsLLVM = "d80eeb9a-aca5-4d75-85e5-170c8b632249" [compat] Adapt = "0.4, 1.0, 2.0, 3.0, 4" Atomix = "0.1" -EnzymeCore = "0.7.1" +EnzymeCore = "0.7.5" InteractiveUtils = "1.6" LinearAlgebra = "1.6" MacroTools = "0.5" diff --git a/ext/EnzymeExt.jl b/ext/EnzymeExt.jl index 4cd30872..1c84c607 100644 --- a/ext/EnzymeExt.jl +++ b/ext/EnzymeExt.jl @@ -70,25 +70,6 @@ module EnzymeExt fwd_kernel(f, args...; ndrange, workgroupsize) end - - @inline function make_active_byref(f::F, ::Val{ActiveTys}) where {F, ActiveTys} - if !any(ActiveTys) - return f - end - function inact(ctx, args2::Vararg{Any, N}) where N - args3 = ntuple(Val(N)) do i - Base.@_inline_meta - if ActiveTys[i] - args2[i][] - else - args2[i] - end - end - f(ctx, args3...) - end - return inact - end - function EnzymeRules.augmented_primal(config::Config, func::Const{<:Kernel{CPU}}, ::Type{Const{Nothing}}, args::Vararg{Any, N}; ndrange=nothing, workgroupsize=nothing) where N kernel = func.val f = kernel.f @@ -102,11 +83,6 @@ module EnzymeExt # TODO autodiff_deferred on the func.val ModifiedBetween = Val((overwritten(config)[1], false, overwritten(config)[2:end]...)) - tup = Val(ntuple(Val(N)) do i - Base.@_inline_meta - args[i] isa Active - end) - f = make_active_byref(f, tup) FT = Const{Core.Typeof(f)} arg_refs = ntuple(Val(N)) do i @@ -120,7 +96,7 @@ module EnzymeExt args2 = ntuple(Val(N)) do i Base.@_inline_meta if args[i] isa Active - Duplicated(Ref(args[i].val), arg_refs[i]) + MixedDuplicated(args[i].val, arg_refs[i]) else args[i] end @@ -150,7 +126,7 @@ module EnzymeExt args2 = ntuple(Val(N)) do i Base.@_inline_meta if args[i] isa Active - Duplicated(Ref(args[i].val), arg_refs[i]) + MixedDuplicated(args[i].val, arg_refs[i]) else args[i] end @@ -159,12 +135,6 @@ module EnzymeExt kernel = func.val f = kernel.f - tup = Val(ntuple(Val(N)) do i - Base.@_inline_meta - args[i] isa Active - end) - f = make_active_byref(f, tup) - ModifiedBetween = Val((overwritten(config)[1], false, overwritten(config)[2:end]...)) rev_kernel = similar(func.val, rev_cpu)