-
Notifications
You must be signed in to change notification settings - Fork 70
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
DI: Reverse dependency structure #2301
base: main
Are you sure you want to change the base?
Conversation
include("reverse_onearg.jl") | ||
include("reverse_twoarg.jl") | ||
|
||
end # module |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
end # module | |
end # module |
f::F, | ||
::AutoEnzyme{<:Union{ForwardMode,Nothing}}, | ||
x, | ||
tx::NTuple, | ||
contexts::Vararg{DI.Context,C}, | ||
) where {F,C} | ||
return DI.NoPushforwardPrep() | ||
end | ||
|
||
function DI.value_and_pushforward( | ||
f::F, | ||
::DI.NoPushforwardPrep, | ||
backend::AutoEnzyme{<:Union{ForwardMode,Nothing}}, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
f::F, | |
::AutoEnzyme{<:Union{ForwardMode,Nothing}}, | |
x, | |
tx::NTuple, | |
contexts::Vararg{DI.Context,C}, | |
) where {F,C} | |
return DI.NoPushforwardPrep() | |
end | |
function DI.value_and_pushforward( | |
f::F, | |
::DI.NoPushforwardPrep, | |
backend::AutoEnzyme{<:Union{ForwardMode,Nothing}}, | |
f::F, | |
::AutoEnzyme{<:Union{ForwardMode, Nothing}}, | |
x, | |
tx::NTuple, | |
contexts::Vararg{DI.Context, C}, | |
) where {F, C} | |
f::F, | |
::DI.NoPushforwardPrep, | |
backend::AutoEnzyme{<:Union{ForwardMode, Nothing}}, | |
x, | |
tx::NTuple{1}, | |
contexts::Vararg{DI.Context, C}, | |
) where {F, C} |
f::F, | ||
::DI.NoPushforwardPrep, | ||
backend::AutoEnzyme{<:Union{ForwardMode,Nothing}}, | ||
x, | ||
tx::NTuple{B}, | ||
contexts::Vararg{DI.Context,C}, | ||
) where {F,B,C} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
f::F, | |
::DI.NoPushforwardPrep, | |
backend::AutoEnzyme{<:Union{ForwardMode,Nothing}}, | |
x, | |
tx::NTuple{B}, | |
contexts::Vararg{DI.Context,C}, | |
) where {F,B,C} | |
f::F, | |
::DI.NoPushforwardPrep, | |
backend::AutoEnzyme{<:Union{ForwardMode, Nothing}}, | |
x, | |
tx::NTuple{B}, | |
contexts::Vararg{DI.Context, C}, | |
) where {F, B, C} |
f::F, | ||
::DI.NoPushforwardPrep, | ||
backend::AutoEnzyme{<:Union{ForwardMode,Nothing}}, | ||
x, | ||
tx::NTuple{1}, | ||
contexts::Vararg{DI.Context,C}, | ||
) where {F,C} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
f::F, | |
::DI.NoPushforwardPrep, | |
backend::AutoEnzyme{<:Union{ForwardMode,Nothing}}, | |
x, | |
tx::NTuple{1}, | |
contexts::Vararg{DI.Context,C}, | |
) where {F,C} | |
f::F, | |
::DI.NoPushforwardPrep, | |
backend::AutoEnzyme{<:Union{ForwardMode, Nothing}}, | |
x, | |
tx::NTuple{1}, | |
contexts::Vararg{DI.Context, C}, | |
) where {F, C} |
f::F, | ||
::DI.NoPushforwardPrep, | ||
backend::AutoEnzyme{<:Union{ForwardMode,Nothing}}, | ||
x, | ||
tx::NTuple{B}, | ||
contexts::Vararg{DI.Context,C}, | ||
) where {F,B,C} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
f::F, | |
::DI.NoPushforwardPrep, | |
backend::AutoEnzyme{<:Union{ForwardMode,Nothing}}, | |
x, | |
tx::NTuple{B}, | |
contexts::Vararg{DI.Context,C}, | |
) where {F,B,C} | |
f::F, | |
::DI.NoPushforwardPrep, | |
backend::AutoEnzyme{<:Union{ForwardMode, Nothing}}, | |
x, | |
tx::NTuple{B}, | |
contexts::Vararg{DI.Context, C}, | |
) where {F, B, C} |
f::F, ::AutoEnzyme{M,Nothing}, mode::Mode, ::Val{B}=Val(1) | ||
) where {F,M,B} | ||
return f | ||
end | ||
|
||
@inline function get_f_and_df( | ||
f::F, ::AutoEnzyme{M,<:Const}, mode::Mode, ::Val{B}=Val(1) | ||
) where {F,M,B} | ||
return Const(f) | ||
end | ||
|
||
@inline function get_f_and_df( | ||
f::F, | ||
::AutoEnzyme{ | ||
M, | ||
<:Union{ | ||
Duplicated, | ||
MixedDuplicated, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
f::F, ::AutoEnzyme{M,Nothing}, mode::Mode, ::Val{B}=Val(1) | |
) where {F,M,B} | |
return f | |
end | |
@inline function get_f_and_df( | |
f::F, ::AutoEnzyme{M,<:Const}, mode::Mode, ::Val{B}=Val(1) | |
) where {F,M,B} | |
return Const(f) | |
end | |
@inline function get_f_and_df( | |
f::F, | |
::AutoEnzyme{ | |
M, | |
<:Union{ | |
Duplicated, | |
MixedDuplicated, | |
f::F, ::AutoEnzyme{M, Nothing}, mode::Mode, ::Val{B} = Val(1) | |
) where {F, M, B} | |
f::F, ::AutoEnzyme{M, <:Const}, mode::Mode, ::Val{B} = Val(1) | |
) where {F, M, B} | |
f::F, | |
::AutoEnzyme{ | |
M, | |
<:Union{ | |
Duplicated, | |
MixedDuplicated, | |
BatchDuplicated, | |
BatchMixedDuplicated, | |
DuplicatedNoNeed, | |
BatchDuplicatedNoNeed, | |
}, | |
mode::Mode, | |
::Val{B} = Val(1), | |
) where {F, M, B} |
force_annotation(f::F) where {F<:Annotation} = f | ||
force_annotation(f::F) where {F} = Const(f) | ||
|
||
@inline function _translate( | ||
::AutoEnzyme, ::Mode, ::Val{B}, c::Union{DI.Constant,DI.BackendContext} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
force_annotation(f::F) where {F<:Annotation} = f | |
force_annotation(f::F) where {F} = Const(f) | |
@inline function _translate( | |
::AutoEnzyme, ::Mode, ::Val{B}, c::Union{DI.Constant,DI.BackendContext} | |
force_annotation(f::F) where {F <: Annotation} = f | |
::AutoEnzyme, ::Mode, ::Val{B}, c::Union{DI.Constant, DI.BackendContext} | |
) where {B} | |
backend::AutoEnzyme, mode::Mode, ::Val{B}, c::DI.Cache | |
) where {B} |
backend::AutoEnzyme, mode::Mode, ::Val{B}, c::DI.FunctionContext | ||
) where {B} | ||
return force_annotation(get_f_and_df(DI.unwrap(c), backend, mode, Val(B))) | ||
end |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
backend::AutoEnzyme, mode::Mode, ::Val{B}, c::DI.FunctionContext | |
) where {B} | |
return force_annotation(get_f_and_df(DI.unwrap(c), backend, mode, Val(B))) | |
end | |
backend::AutoEnzyme, mode::Mode, ::Val{B}, c::DI.FunctionContext | |
) where {B} | |
backend::AutoEnzyme, mode::Mode, ::Val{B}, contexts::Vararg{DI.Context, C} | |
) where {B, C} |
set_err(mode::Mode, ::AutoEnzyme{<:Any,Nothing}) = EnzymeCore.set_err_if_func_written(mode) | ||
set_err(mode::Mode, ::AutoEnzyme{<:Any,<:Annotation}) = mode |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
set_err(mode::Mode, ::AutoEnzyme{<:Any,Nothing}) = EnzymeCore.set_err_if_func_written(mode) | |
set_err(mode::Mode, ::AutoEnzyme{<:Any,<:Annotation}) = mode | |
set_err(mode::Mode, ::AutoEnzyme{<:Any, Nothing}) = EnzymeCore.set_err_if_func_written(mode) | |
set_err(mode::Mode, ::AutoEnzyme{<:Any, <:Annotation}) = mode |
function annotate(::Type{BatchDuplicated{T,B}}, x, tx::NTuple{B}) where {T,B} | ||
return BatchDuplicated(x, tx) | ||
end |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
function annotate(::Type{BatchDuplicated{T,B}}, x, tx::NTuple{B}) where {T,B} | |
return BatchDuplicated(x, tx) | |
end | |
function annotate(::Type{BatchDuplicated{T, B}}, x, tx::NTuple{B}) where {T, B} | |
batchify_activity(::Type{Active{T}}, ::Val{B}) where {T, B} = Active{T} | |
batchify_activity(::Type{Duplicated{T}}, ::Val{B}) where {T, B} = BatchDuplicated{T, B} |
Benchmark Results
Benchmark PlotsA plot of the benchmark results have been uploaded as an artifact to the workflow run for this PR. |
return set_err(ReverseSplitWithPrimal, backend) | ||
end | ||
|
||
set_err(mode::Mode, ::AutoEnzyme{<:Any,Nothing}) = EnzymeCore.set_err_if_func_written(mode) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@gdalle as discussed on slack this probably should be an extension to the set_err_if_func_written function to take an ADmode, so likely we have this in an EnzymeCoreADTypes ext?
forward_withprimal(backend::AutoEnzyme{<:ForwardMode}) = WithPrimal(backend.mode) | ||
forward_withprimal(::AutoEnzyme{Nothing}) = ForwardWithPrimal | ||
|
||
reverse_noprimal(backend::AutoEnzyme{<:ReverseMode}) = NoPrimal(backend.mode) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@gdalle similarly here we can make an ADTypes ext func for get_mode_or_default(AutoEnzyme, defaultMode)
dy_sametype = convert(typeof(y), only(prep.ty_copy)) | ||
x_and_dx = Duplicated(x, dx_sametype) | ||
y_and_dy = Duplicated(y, dy_sametype) | ||
annotated_contexts = translate(backend, mode, Val(1), contexts...) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@gdalle I presume this can be moved into Enzyme.gradient! And have DI call that?
@gdalle this will enable the DI ext to more properly touch internals if need be