Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Enzyme: simplify via mixedduplicated #483

Merged
merged 1 commit into from
Jun 21, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ UnsafeAtomicsLLVM = "d80eeb9a-aca5-4d75-85e5-170c8b632249"
[compat]
Adapt = "0.4, 1.0, 2.0, 3.0, 4"
Atomix = "0.1"
EnzymeCore = "0.7.1"
EnzymeCore = "0.7.5"
InteractiveUtils = "1.6"
LinearAlgebra = "1.6"
MacroTools = "0.5"
Expand Down
34 changes: 2 additions & 32 deletions ext/EnzymeExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -70,25 +70,6 @@ module EnzymeExt
fwd_kernel(f, args...; ndrange, workgroupsize)
end


@inline function make_active_byref(f::F, ::Val{ActiveTys}) where {F, ActiveTys}
if !any(ActiveTys)
return f
end
function inact(ctx, args2::Vararg{Any, N}) where N
args3 = ntuple(Val(N)) do i
Base.@_inline_meta
if ActiveTys[i]
args2[i][]
else
args2[i]
end
end
f(ctx, args3...)
end
return inact
end

function EnzymeRules.augmented_primal(config::Config, func::Const{<:Kernel{CPU}}, ::Type{Const{Nothing}}, args::Vararg{Any, N}; ndrange=nothing, workgroupsize=nothing) where N
kernel = func.val
f = kernel.f
Expand All @@ -102,11 +83,6 @@ module EnzymeExt
# TODO autodiff_deferred on the func.val
ModifiedBetween = Val((overwritten(config)[1], false, overwritten(config)[2:end]...))

tup = Val(ntuple(Val(N)) do i
Base.@_inline_meta
args[i] isa Active
end)
f = make_active_byref(f, tup)
FT = Const{Core.Typeof(f)}

arg_refs = ntuple(Val(N)) do i
Expand All @@ -120,7 +96,7 @@ module EnzymeExt
args2 = ntuple(Val(N)) do i
Base.@_inline_meta
if args[i] isa Active
Duplicated(Ref(args[i].val), arg_refs[i])
EnzymeCore.MixedDuplicated(args[i].val, arg_refs[i])
else
args[i]
end
Expand Down Expand Up @@ -150,7 +126,7 @@ module EnzymeExt
args2 = ntuple(Val(N)) do i
Base.@_inline_meta
if args[i] isa Active
Duplicated(Ref(args[i].val), arg_refs[i])
EnzymeCore.MixedDuplicated(args[i].val, arg_refs[i])
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is correct for the CPU, but I think we won't be able to pass MixedDuplicated to the GPU.
Or at least we need an Adapt rule.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we not have those already like we do for the other annotations?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Apparently not, now we do here: EnzymeAD/Enzyme.jl#1551

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We should double check that this adapts to what we want on CUDA. I don't know if CuRef is right here...

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sure but this is code wouldn’t affect any currently landed cuda support (this is only reverse mode).

So we should likely explore that in michel’s PR?

Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

And what about non CUDA GPUs?

else
args[i]
end
Expand All @@ -159,12 +135,6 @@ module EnzymeExt
kernel = func.val
f = kernel.f

tup = Val(ntuple(Val(N)) do i
Base.@_inline_meta
args[i] isa Active
end)
f = make_active_byref(f, tup)

ModifiedBetween = Val((overwritten(config)[1], false, overwritten(config)[2:end]...))

rev_kernel = similar(func.val, rev_cpu)
Expand Down
Loading