Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve Mooncake caching #513

Merged
merged 3 commits into from
Sep 30, 2024
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,9 @@ using Mooncake:

# Report the `AutoMooncake` backend as available: this method always returns `true`.
DI.check_available(::AutoMooncake) = true

"""
    copyto!!(dst, src)

Return the contents of `src` expressed in the storage of `dst`.

When both arguments are `Number`s nothing can be mutated, so `src` is converted to
`typeof(dst)` and the converted value is returned. For every other combination the
call is forwarded to `copyto!(dst, src)` and its result is returned.
"""
function copyto!!(dst::Number, src::Number)
    return convert(typeof(dst), src)
end

function copyto!!(dst, src)
    return copyto!(dst, src)
end

# Resolve the `Config` to use for a backend: a backend parameterised by `Nothing`
# gets a freshly constructed `Config()`, while one parameterised by a `Config`
# subtype hands back its own `config` field.
function get_config(::AutoMooncake{Nothing})
    return Config()
end

function get_config(backend::AutoMooncake{<:Config})
    return backend.config
end

Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
struct MooncakeOneArgPullbackPrep{Y,R} <: PullbackPrep
struct MooncakeOneArgPullbackPrep{Y,R,DX,DY} <: PullbackPrep
y_prototype::Y
rrule::R
dx_righttype::DX
dy_righttype::DY
end

function DI.prepare_pullback(
Expand All @@ -14,7 +16,9 @@ function DI.prepare_pullback(
debug_mode=config.debug_mode,
silence_debug_messages=config.silence_debug_messages,
)
prep = MooncakeOneArgPullbackPrep(y, rrule)
dx_righttype = zero_tangent(x)
dy_righttype = zero_tangent(y)
prep = MooncakeOneArgPullbackPrep(y, rrule, dx_righttype, dy_righttype)
DI.value_and_pullback(f, prep, backend, x, ty, contexts...) # warm up
return prep
end
Expand All @@ -28,7 +32,7 @@ function DI.value_and_pullback(
contexts::Vararg{Context,C},
) where {Y,C}
dy = only(ty)
dy_righttype = convert(tangent_type(Y), dy)
dy_righttype = dy isa tangent_type(Y) ? dy : copyto!!(prep.dy_righttype, dy)
new_y, (_, new_dx) = value_and_pullback!!(
prep.rrule, dy_righttype, f, x, map(unwrap, contexts)...
)
Expand All @@ -45,8 +49,8 @@ function DI.value_and_pullback!(
contexts::Vararg{Context,C},
) where {Y,C}
dx, dy = only(tx), only(ty)
dy_righttype = convert(tangent_type(Y), dy)
dx_righttype = set_to_zero!!(convert(tangent_type(typeof(x)), dx))
dy_righttype = dy isa tangent_type(Y) ? dy : copyto!!(prep.dy_righttype, dy)
dx_righttype = set_to_zero!!(prep.dx_righttype)
contexts_coduals = map(zero_codual ∘ unwrap, contexts)
gdalle marked this conversation as resolved.
Show resolved Hide resolved
y, (_, new_dx) = __value_and_pullback!!(
prep.rrule,
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,10 @@
struct MooncakeTwoArgPullbackPrep{R} <: PullbackPrep
struct MooncakeTwoArgPullbackPrep{R,F,Y,DX,DY} <: PullbackPrep
rrule::R
df!::F
y_copy::Y
dx_righttype::DX
dy_righttype::DY
dy_righttype_after::DY
end

function DI.prepare_pullback(
Expand All @@ -12,7 +17,14 @@ function DI.prepare_pullback(
debug_mode=config.debug_mode,
silence_debug_messages=config.silence_debug_messages,
)
prep = MooncakeTwoArgPullbackPrep(rrule)
df! = zero_tangent(f!)
y_copy = copy(y)
dx_righttype = zero_tangent(x)
dy_righttype = zero_tangent(y)
dy_righttype_after = zero_tangent(y)
prep = MooncakeTwoArgPullbackPrep(
rrule, df!, y_copy, dx_righttype, dy_righttype, dy_righttype_after
)
DI.value_and_pullback(f!, y, prep, backend, x, ty, contexts...) # warm up
return prep
end
Expand All @@ -27,43 +39,36 @@ function DI.value_and_pullback(
contexts::Vararg{Context,C},
) where {C}
dy = only(ty)
dy_righttype = convert(tangent_type(typeof(y)), copy(dy))

# Set all tangent storage to zero.
df! = set_to_zero!!(prep.df!)
# dx_righttype = set_to_zero!!(prep.dx_righttype) # TODO: why doesn't this work?
dx_righttype = zero_tangent(x)
gdalle marked this conversation as resolved.
Show resolved Hide resolved
dy_righttype = set_to_zero!!(prep.dy_righttype)

# We want the VJP, not VJP + dx, so I'm going to zero-out `dx`. `set_to_zero!!` has the advantage
# that it will also replace any immutable components of `dx` to zero.
dx_righttype = set_to_zero!!(dx_righttype)
# Prepare cotangent to add after the forward pass.
dy_righttype_after = copyto!(prep.dy_righttype_after, dy)

# We want `dy` to correspond to the cotangent of `y` _after_
# running the forwards-pass, so I'm going to take a copy, and zero-out the original.
dy_righttype_backup = copy(dy_righttype)
dy_righttype = set_to_zero!!(dy_righttype)
contexts_coduals = map(zero_fcodual ∘ unwrap, contexts)

# Mutate a copy of `y`, so that we can run the reverse-pass later on.
y_copy = copy(y)

# In case `f!` is a closure
df! = zero_tangent(f!)

# Run the forwards-pass.
# Run the forward pass
out, pb!! = prep.rrule(
CoDual(f!, fdata(df!)),
CoDual(y_copy, fdata(dy_righttype)),
CoDual(prep.y_copy, fdata(dy_righttype)),
CoDual(x, fdata(dx_righttype)),
contexts_coduals...,
)

# Verify that the output is non-differentiable.
@assert primal(out) === nothing

# Set the cotangent of `y` to be equal to the requested value.
dy_righttype = increment!!(dy_righttype, dy_righttype_backup)
# Increment the desired cotangent dy.
dy_righttype = increment!!(dy_righttype, dy_righttype_after)

# Record the state of `y` before running the reverse-pass.
y = copyto!(y, y_copy)
# Record the state of y before running the reverse pass.
y = copyto!(y, prep.y_copy)

# Run the reverse-pass.
# Run the reverse pass.
_, _, new_dx = pb!!(NoRData())

return y, (tangent(fdata(dx_righttype), new_dx),)
Expand Down
Loading