Skip to content

Commit

Permalink
Improve Mooncake caching (#513)
Browse files Browse the repository at this point in the history
* Improve Mooncake caching

* Fix
  • Loading branch information
gdalle authored Sep 30, 2024
1 parent 3f29b61 commit 7ee5859
Show file tree
Hide file tree
Showing 3 changed files with 41 additions and 30 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,9 @@ using Mooncake:

DI.check_available(::AutoMooncake) = true

copyto!!(dst::Number, src::Number) = convert(typeof(dst), src)
copyto!!(dst, src) = copyto!(dst, src)

get_config(::AutoMooncake{Nothing}) = Config()
get_config(backend::AutoMooncake{<:Config}) = backend.config

Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
struct MooncakeOneArgPullbackPrep{Y,R} <: PullbackPrep
struct MooncakeOneArgPullbackPrep{Y,R,DX,DY} <: PullbackPrep
y_prototype::Y
rrule::R
dx_righttype::DX
dy_righttype::DY
end

function DI.prepare_pullback(
Expand All @@ -14,7 +16,9 @@ function DI.prepare_pullback(
debug_mode=config.debug_mode,
silence_debug_messages=config.silence_debug_messages,
)
prep = MooncakeOneArgPullbackPrep(y, rrule)
dx_righttype = zero_tangent(x)
dy_righttype = zero_tangent(y)
prep = MooncakeOneArgPullbackPrep(y, rrule, dx_righttype, dy_righttype)
DI.value_and_pullback(f, prep, backend, x, ty, contexts...) # warm up
return prep
end
Expand All @@ -28,7 +32,7 @@ function DI.value_and_pullback(
contexts::Vararg{Context,C},
) where {Y,C}
dy = only(ty)
dy_righttype = convert(tangent_type(Y), dy)
dy_righttype = dy isa tangent_type(Y) ? dy : copyto!!(prep.dy_righttype, dy)
new_y, (_, new_dx) = value_and_pullback!!(
prep.rrule, dy_righttype, f, x, map(unwrap, contexts)...
)
Expand All @@ -45,8 +49,8 @@ function DI.value_and_pullback!(
contexts::Vararg{Context,C},
) where {Y,C}
dx, dy = only(tx), only(ty)
dy_righttype = convert(tangent_type(Y), dy)
dx_righttype = set_to_zero!!(convert(tangent_type(typeof(x)), dx))
dy_righttype = dy isa tangent_type(Y) ? dy : copyto!!(prep.dy_righttype, dy)
dx_righttype = set_to_zero!!(prep.dx_righttype)
contexts_coduals = map(zero_codual unwrap, contexts)
y, (_, new_dx) = __value_and_pullback!!(
prep.rrule,
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,10 @@
struct MooncakeTwoArgPullbackPrep{R} <: PullbackPrep
struct MooncakeTwoArgPullbackPrep{R,F,Y,DX,DY} <: PullbackPrep
rrule::R
df!::F
y_copy::Y
dx_righttype::DX
dy_righttype::DY
dy_righttype_after::DY
end

function DI.prepare_pullback(
Expand All @@ -12,7 +17,14 @@ function DI.prepare_pullback(
debug_mode=config.debug_mode,
silence_debug_messages=config.silence_debug_messages,
)
prep = MooncakeTwoArgPullbackPrep(rrule)
df! = zero_tangent(f!)
y_copy = copy(y)
dx_righttype = zero_tangent(x)
dy_righttype = zero_tangent(y)
dy_righttype_after = zero_tangent(y)
prep = MooncakeTwoArgPullbackPrep(
rrule, df!, y_copy, dx_righttype, dy_righttype, dy_righttype_after
)
DI.value_and_pullback(f!, y, prep, backend, x, ty, contexts...) # warm up
return prep
end
Expand All @@ -27,46 +39,38 @@ function DI.value_and_pullback(
contexts::Vararg{Context,C},
) where {C}
dy = only(ty)
dy_righttype = convert(tangent_type(typeof(y)), copy(dy))
dx_righttype = zero_tangent(x)

# We want the VJP, not VJP + dx, so I'm going to zero-out `dx`. `set_to_zero!!` has the advantage
# that it will also replace any immutable components of `dx` to zero.
dx_righttype = set_to_zero!!(dx_righttype)
# Set all tangent storage to zero.
df! = set_to_zero!!(prep.df!)
dx_righttype = set_to_zero!!(prep.dx_righttype)
dy_righttype = set_to_zero!!(prep.dy_righttype)

# We want `dy` to correspond to the cotangent of `y` _after_
# running the forwards-pass, so I'm going to take a copy, and zero-out the original.
dy_righttype_backup = copy(dy_righttype)
dy_righttype = set_to_zero!!(dy_righttype)
contexts_coduals = map(zero_fcodual unwrap, contexts)

# Mutate a copy of `y`, so that we can run the reverse-pass later on.
y_copy = copy(y)
# Prepare cotangent to add after the forward pass.
dy_righttype_after = copyto!(prep.dy_righttype_after, dy)

# In case `f!` is a closure
df! = zero_tangent(f!)
contexts_coduals = map(zero_fcodual unwrap, contexts)

# Run the forwards-pass.
# Run the forward pass
out, pb!! = prep.rrule(
CoDual(f!, fdata(df!)),
CoDual(y_copy, fdata(dy_righttype)),
CoDual(prep.y_copy, fdata(dy_righttype)),
CoDual(x, fdata(dx_righttype)),
contexts_coduals...,
)

# Verify that the output is non-differentiable.
@assert primal(out) === nothing

# Set the cotangent of `y` to be equal to the requested value.
dy_righttype = increment!!(dy_righttype, dy_righttype_backup)
# Increment the desired cotangent dy.
dy_righttype = increment!!(dy_righttype, dy_righttype_after)

# Record the state of `y` before running the reverse-pass.
y = copyto!(y, y_copy)
# Record the state of y before running the reverse pass.
y = copyto!(y, prep.y_copy)

# Run the reverse-pass.
# Run the reverse pass.
_, _, new_dx = pb!!(NoRData())

return y, (tangent(fdata(dx_righttype), new_dx),)
return y, (tangent(copy(fdata(dx_righttype)), new_dx),) # TODO: remove this allocation in `value_and_pullback!`
end

function DI.value_and_pullback(
Expand Down

0 comments on commit 7ee5859

Please sign in to comment.