Skip to content

Commit

Permalink
Merge pull request #176 from omlins/ad
Browse files Browse the repository at this point in the history
Improve AD module
  • Loading branch information
omlins authored Dec 4, 2024
2 parents c5a945a + fe0c955 commit bb60399
Show file tree
Hide file tree
Showing 5 changed files with 12 additions and 17 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ ParallelStencil_MetalExt = "Metal"
AMDGPU = "0.6, 0.7, 0.8, 0.9, 1"
CUDA = "3.12, 4, 5"
CellArrays = "0.3"
Enzyme = "0.11, 0.12, 0.13"
Enzyme = "0.12, 0.13"
MacroTools = "0.5"
Metal = "1.2"
Polyester = "0.7"
Expand Down
7 changes: 2 additions & 5 deletions src/AD.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,8 @@ Provides GPU-compatible wrappers for automatic differentiation functions of the
import ParallelStencil.AD
# Functions
- `autodiff_deferred!`: wraps function `autodiff_deferred`.
- `autodiff_deferred_thunk!`: wraps function `autodiff_deferred_thunk`.
- `autodiff_deferred!`: wraps function `autodiff_deferred`, promoting all arguments that are not Enzyme.Annotations to Enzyme.Const.
- `autodiff_deferred_thunk!`: wraps function `autodiff_deferred_thunk`, promoting all arguments that are not Enzyme.Annotations to Enzyme.Const.
# Examples
const USE_GPU = true
Expand Down Expand Up @@ -43,9 +43,6 @@ Provides GPU-compatible wrappers for automatic differentiation functions of the
main()
!!! note "Enzyme runtime activity default"
If ParallelStencil is initialized with Threads, then `Enzyme.API.runtimeActivity!(true)` is called to ensure correct behavior of Enzyme. If you want to disable this behavior, then call `Enzyme.API.runtimeActivity!(false)` after loading ParallelStencil.
To see a description of a function type `?<functionname>`.
"""
module AD
Expand Down
7 changes: 2 additions & 5 deletions src/ParallelKernel/EnzymeExt/AD.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,8 @@ Provides GPU-compatible wrappers for automatic differentiation functions of the
import ParallelKernel.AD
# Functions
- `autodiff_deferred!`: wraps function `autodiff_deferred`.
- `autodiff_deferred_thunk!`: wraps function `autodiff_deferred_thunk`.
!!! note "Enzyme runtime activity default"
If ParallelKernel is initialized with Threads, then `Enzyme.API.runtimeActivity!(true)` is called to ensure correct behavior of Enzyme. If you want to disable this behavior, then call `Enzyme.API.runtimeActivity!(false)` after loading ParallelStencil.
- `autodiff_deferred!`: wraps function `autodiff_deferred`, promoting all arguments that are not Enzyme.Annotations to Enzyme.Const.
- `autodiff_deferred_thunk!`: wraps function `autodiff_deferred_thunk`, promoting all arguments that are not Enzyme.Annotations to Enzyme.Const.
To see a description of a function type `?<functionname>`.
"""
Expand Down
9 changes: 5 additions & 4 deletions src/ParallelKernel/EnzymeExt/autodiff_gpu.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,16 +2,17 @@ import ParallelStencil
import ParallelStencil: PKG_THREADS, PKG_POLYESTER
import Enzyme

# NOTE: package specific initialization of Enzyme could be done as follows (not needed in the currently supported versions of Enzyme)
# function ParallelStencil.ParallelKernel.AD.init_AD(package::Symbol)
# if iscpu(package)
# Enzyme.API.runtimeActivity!(true) # NOTE: this is currently required for Enzyme to work correctly with threads
# end
# end

# ParallelStencil injects a configuration parameter at the end, for Enzyme we need to wrap that parameter as a Annotation
# for all purposes this ought to be Const. This is not ideal since we might accidentially wrap other parameters the user
# provided as well. This is needed to support @parallel autodiff_deferred(...)
function promote_to_const(args::Vararg{Any,N}) where N
# NOTE: @parallel injects four parameters at the end, which need to be wrapped as Annotations. The current solution is to wrap all
# arguments which are not already Annotations (all the other arguments must be Annotations). Should this change, then one could
# explicitly wrap just the injected parameters.
function promote_to_const(args::Vararg{Any,N}) where N
ntuple(Val(N)) do i
@inline
if !(args[i] isa Enzyme.Annotation ||
Expand Down
4 changes: 2 additions & 2 deletions test/ParallelKernel/test_parallel.jl
Original file line number Diff line number Diff line change
Expand Up @@ -133,8 +133,8 @@ eval(:(
end
return
end
@parallel configcall=f!(A, B, a) AD.autodiff_deferred!(Enzyme.Reverse, Const(f!), Const, DuplicatedNoNeed(A, Ā), DuplicatedNoNeed(B, B̄), Const(a))
Enzyme.autodiff_deferred(Enzyme.Reverse, Const(g!),Const, DuplicatedNoNeed(A_ref, Ā_ref), DuplicatedNoNeed(B_ref, B̄_ref), Const(a))
@parallel configcall=f!(A, B, a) AD.autodiff_deferred!(Enzyme.Reverse, f!, Const, DuplicatedNoNeed(A, Ā), DuplicatedNoNeed(B, B̄), Const(a)) # NOTE: f! is automatically promoted to Const.
Enzyme.autodiff_deferred(Enzyme.Reverse, Const(g!), Const, DuplicatedNoNeed(A_ref, Ā_ref), DuplicatedNoNeed(B_ref, B̄_ref), Const(a))
@test Array(Ā) Ā_ref
@test Array(B̄) B̄_ref
end
Expand Down

0 comments on commit bb60399

Please sign in to comment.