Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

DI: Reverse dependency structure #2301

Draft
wants to merge 1 commit into
base: main
Choose a base branch
from
Draft

DI: Reverse dependency structure #2301

wants to merge 1 commit into from

Conversation

wsmoses
Copy link
Member

@wsmoses wsmoses commented Feb 8, 2025

@gdalle this will enable the DI ext to more properly touch internals if need be

include("reverse_onearg.jl")
include("reverse_twoarg.jl")

end # module
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
end # module
end # module

Comment on lines +4 to +16
f::F,
::AutoEnzyme{<:Union{ForwardMode,Nothing}},
x,
tx::NTuple,
contexts::Vararg{DI.Context,C},
) where {F,C}
return DI.NoPushforwardPrep()
end

function DI.value_and_pushforward(
f::F,
::DI.NoPushforwardPrep,
backend::AutoEnzyme{<:Union{ForwardMode,Nothing}},
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
f::F,
::AutoEnzyme{<:Union{ForwardMode,Nothing}},
x,
tx::NTuple,
contexts::Vararg{DI.Context,C},
) where {F,C}
return DI.NoPushforwardPrep()
end
function DI.value_and_pushforward(
f::F,
::DI.NoPushforwardPrep,
backend::AutoEnzyme{<:Union{ForwardMode,Nothing}},
f::F,
::AutoEnzyme{<:Union{ForwardMode, Nothing}},
x,
tx::NTuple,
contexts::Vararg{DI.Context, C},
) where {F, C}
f::F,
::DI.NoPushforwardPrep,
backend::AutoEnzyme{<:Union{ForwardMode, Nothing}},
x,
tx::NTuple{1},
contexts::Vararg{DI.Context, C},
) where {F, C}

Comment on lines +31 to +37
f::F,
::DI.NoPushforwardPrep,
backend::AutoEnzyme{<:Union{ForwardMode,Nothing}},
x,
tx::NTuple{B},
contexts::Vararg{DI.Context,C},
) where {F,B,C}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
f::F,
::DI.NoPushforwardPrep,
backend::AutoEnzyme{<:Union{ForwardMode,Nothing}},
x,
tx::NTuple{B},
contexts::Vararg{DI.Context,C},
) where {F,B,C}
f::F,
::DI.NoPushforwardPrep,
backend::AutoEnzyme{<:Union{ForwardMode, Nothing}},
x,
tx::NTuple{B},
contexts::Vararg{DI.Context, C},
) where {F, B, C}

Comment on lines +48 to +54
f::F,
::DI.NoPushforwardPrep,
backend::AutoEnzyme{<:Union{ForwardMode,Nothing}},
x,
tx::NTuple{1},
contexts::Vararg{DI.Context,C},
) where {F,C}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
f::F,
::DI.NoPushforwardPrep,
backend::AutoEnzyme{<:Union{ForwardMode,Nothing}},
x,
tx::NTuple{1},
contexts::Vararg{DI.Context,C},
) where {F,C}
f::F,
::DI.NoPushforwardPrep,
backend::AutoEnzyme{<:Union{ForwardMode, Nothing}},
x,
tx::NTuple{1},
contexts::Vararg{DI.Context, C},
) where {F, C}

Comment on lines +65 to +71
f::F,
::DI.NoPushforwardPrep,
backend::AutoEnzyme{<:Union{ForwardMode,Nothing}},
x,
tx::NTuple{B},
contexts::Vararg{DI.Context,C},
) where {F,B,C}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
f::F,
::DI.NoPushforwardPrep,
backend::AutoEnzyme{<:Union{ForwardMode,Nothing}},
x,
tx::NTuple{B},
contexts::Vararg{DI.Context,C},
) where {F,B,C}
f::F,
::DI.NoPushforwardPrep,
backend::AutoEnzyme{<:Union{ForwardMode, Nothing}},
x,
tx::NTuple{B},
contexts::Vararg{DI.Context, C},
) where {F, B, C}

Comment on lines +12 to +29
f::F, ::AutoEnzyme{M,Nothing}, mode::Mode, ::Val{B}=Val(1)
) where {F,M,B}
return f
end

@inline function get_f_and_df(
f::F, ::AutoEnzyme{M,<:Const}, mode::Mode, ::Val{B}=Val(1)
) where {F,M,B}
return Const(f)
end

@inline function get_f_and_df(
f::F,
::AutoEnzyme{
M,
<:Union{
Duplicated,
MixedDuplicated,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
f::F, ::AutoEnzyme{M,Nothing}, mode::Mode, ::Val{B}=Val(1)
) where {F,M,B}
return f
end
@inline function get_f_and_df(
f::F, ::AutoEnzyme{M,<:Const}, mode::Mode, ::Val{B}=Val(1)
) where {F,M,B}
return Const(f)
end
@inline function get_f_and_df(
f::F,
::AutoEnzyme{
M,
<:Union{
Duplicated,
MixedDuplicated,
f::F, ::AutoEnzyme{M, Nothing}, mode::Mode, ::Val{B} = Val(1)
) where {F, M, B}
f::F, ::AutoEnzyme{M, <:Const}, mode::Mode, ::Val{B} = Val(1)
) where {F, M, B}
f::F,
::AutoEnzyme{
M,
<:Union{
Duplicated,
MixedDuplicated,
BatchDuplicated,
BatchMixedDuplicated,
DuplicatedNoNeed,
BatchDuplicatedNoNeed,
},
mode::Mode,
::Val{B} = Val(1),
) where {F, M, B}

Comment on lines +47 to +51
force_annotation(f::F) where {F<:Annotation} = f
force_annotation(f::F) where {F} = Const(f)

@inline function _translate(
::AutoEnzyme, ::Mode, ::Val{B}, c::Union{DI.Constant,DI.BackendContext}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
force_annotation(f::F) where {F<:Annotation} = f
force_annotation(f::F) where {F} = Const(f)
@inline function _translate(
::AutoEnzyme, ::Mode, ::Val{B}, c::Union{DI.Constant,DI.BackendContext}
force_annotation(f::F) where {F <: Annotation} = f
::AutoEnzyme, ::Mode, ::Val{B}, c::Union{DI.Constant, DI.BackendContext}
) where {B}
backend::AutoEnzyme, mode::Mode, ::Val{B}, c::DI.Cache
) where {B}

Comment on lines +67 to +70
backend::AutoEnzyme, mode::Mode, ::Val{B}, c::DI.FunctionContext
) where {B}
return force_annotation(get_f_and_df(DI.unwrap(c), backend, mode, Val(B)))
end
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
backend::AutoEnzyme, mode::Mode, ::Val{B}, c::DI.FunctionContext
) where {B}
return force_annotation(get_f_and_df(DI.unwrap(c), backend, mode, Val(B)))
end
backend::AutoEnzyme, mode::Mode, ::Val{B}, c::DI.FunctionContext
) where {B}
backend::AutoEnzyme, mode::Mode, ::Val{B}, contexts::Vararg{DI.Context, C}
) where {B, C}

Comment on lines +103 to +104
set_err(mode::Mode, ::AutoEnzyme{<:Any,Nothing}) = EnzymeCore.set_err_if_func_written(mode)
set_err(mode::Mode, ::AutoEnzyme{<:Any,<:Annotation}) = mode
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
set_err(mode::Mode, ::AutoEnzyme{<:Any,Nothing}) = EnzymeCore.set_err_if_func_written(mode)
set_err(mode::Mode, ::AutoEnzyme{<:Any,<:Annotation}) = mode
set_err(mode::Mode, ::AutoEnzyme{<:Any, Nothing}) = EnzymeCore.set_err_if_func_written(mode)
set_err(mode::Mode, ::AutoEnzyme{<:Any, <:Annotation}) = mode

Comment on lines +118 to +120
function annotate(::Type{BatchDuplicated{T,B}}, x, tx::NTuple{B}) where {T,B}
return BatchDuplicated(x, tx)
end
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
function annotate(::Type{BatchDuplicated{T,B}}, x, tx::NTuple{B}) where {T,B}
return BatchDuplicated(x, tx)
end
function annotate(::Type{BatchDuplicated{T, B}}, x, tx::NTuple{B}) where {T, B}
batchify_activity(::Type{Active{T}}, ::Val{B}) where {T, B} = Active{T}
batchify_activity(::Type{Duplicated{T}}, ::Val{B}) where {T, B} = BatchDuplicated{T, B}

Copy link
Contributor

github-actions bot commented Feb 8, 2025

Benchmark Results

main 00803ee... main/00803ee98e24c7...
basics/overhead 5.26 ± 0.01 ns 4.64 ± 0.01 ns 1.13
time_to_load 1.1 ± 0.012 s 1.13 ± 0.034 s 0.967

Benchmark Plots

A plot of the benchmark results have been uploaded as an artifact to the workflow run for this PR.
Go to "Actions"->"Benchmark a pull request"->[the most recent run]->"Artifacts" (at the bottom).

@wsmoses wsmoses marked this pull request as draft February 8, 2025 18:34
return set_err(ReverseSplitWithPrimal, backend)
end

set_err(mode::Mode, ::AutoEnzyme{<:Any,Nothing}) = EnzymeCore.set_err_if_func_written(mode)
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@gdalle as discussed on slack this probably should be an extension to the set_err_if_func_written function to take an ADmode, so likely we have this in an EnzymeCoreADTypes ext?

forward_withprimal(backend::AutoEnzyme{<:ForwardMode}) = WithPrimal(backend.mode)
forward_withprimal(::AutoEnzyme{Nothing}) = ForwardWithPrimal

reverse_noprimal(backend::AutoEnzyme{<:ReverseMode}) = NoPrimal(backend.mode)
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@gdalle similarly here we can make an ADTypes ext func for get_mode_or_default(AutoEnzyme, defaultMode)

dy_sametype = convert(typeof(y), only(prep.ty_copy))
x_and_dx = Duplicated(x, dx_sametype)
y_and_dy = Duplicated(y, dy_sametype)
annotated_contexts = translate(backend, mode, Val(1), contexts...)
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@gdalle I presume this can be moved into Enzyme.gradient! And have DI call that?

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

1 participant