diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/DifferentiationInterfaceMooncakeExt.jl b/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/DifferentiationInterfaceMooncakeExt.jl index cb6bcd349..fc69bb5b4 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/DifferentiationInterfaceMooncakeExt.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/DifferentiationInterfaceMooncakeExt.jl @@ -25,6 +25,9 @@ using Mooncake: DI.check_available(::AutoMooncake) = true +copyto!!(dst::Number, src::Number) = convert(typeof(dst), src) +copyto!!(dst, src) = copyto!(dst, src) + get_config(::AutoMooncake{Nothing}) = Config() get_config(backend::AutoMooncake{<:Config}) = backend.config diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/onearg.jl b/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/onearg.jl index 02c04bb22..8157b34bc 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/onearg.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/onearg.jl @@ -1,6 +1,8 @@ -struct MooncakeOneArgPullbackPrep{Y,R} <: PullbackPrep +struct MooncakeOneArgPullbackPrep{Y,R,DX,DY} <: PullbackPrep y_prototype::Y rrule::R + dx_righttype::DX + dy_righttype::DY end function DI.prepare_pullback( @@ -14,7 +16,9 @@ function DI.prepare_pullback( debug_mode=config.debug_mode, silence_debug_messages=config.silence_debug_messages, ) - prep = MooncakeOneArgPullbackPrep(y, rrule) + dx_righttype = zero_tangent(x) + dy_righttype = zero_tangent(y) + prep = MooncakeOneArgPullbackPrep(y, rrule, dx_righttype, dy_righttype) DI.value_and_pullback(f, prep, backend, x, ty, contexts...) # warm up return prep end @@ -28,7 +32,7 @@ function DI.value_and_pullback( contexts::Vararg{Context,C}, ) where {Y,C} dy = only(ty) - dy_righttype = convert(tangent_type(Y), dy) + dy_righttype = dy isa tangent_type(Y) ? dy : copyto!!(prep.dy_righttype, dy) new_y, (_, new_dx) = value_and_pullback!!( prep.rrule, dy_righttype, f, x, map(unwrap, contexts)... ) @@ -45,8 +49,8 @@ function DI.value_and_pullback!( contexts::Vararg{Context,C}, ) where {Y,C} dx, dy = only(tx), only(ty) - dy_righttype = convert(tangent_type(Y), dy) - dx_righttype = set_to_zero!!(convert(tangent_type(typeof(x)), dx)) + dy_righttype = dy isa tangent_type(Y) ? dy : copyto!!(prep.dy_righttype, dy) + dx_righttype = set_to_zero!!(prep.dx_righttype) contexts_coduals = map(zero_codual ∘ unwrap, contexts) y, (_, new_dx) = __value_and_pullback!!( prep.rrule, diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/twoarg.jl b/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/twoarg.jl index c0fbbc406..f1c158435 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/twoarg.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/twoarg.jl @@ -1,5 +1,10 @@ -struct MooncakeTwoArgPullbackPrep{R} <: PullbackPrep +struct MooncakeTwoArgPullbackPrep{R,F,Y,DX,DY} <: PullbackPrep rrule::R + df!::F + y_copy::Y + dx_righttype::DX + dy_righttype::DY + dy_righttype_after::DY end function DI.prepare_pullback( @@ -12,7 +17,14 @@ function DI.prepare_pullback( debug_mode=config.debug_mode, silence_debug_messages=config.silence_debug_messages, ) - prep = MooncakeTwoArgPullbackPrep(rrule) + df! = zero_tangent(f!) + y_copy = copy(y) + dx_righttype = zero_tangent(x) + dy_righttype = zero_tangent(y) + dy_righttype_after = zero_tangent(y) + prep = MooncakeTwoArgPullbackPrep( + rrule, df!, y_copy, dx_righttype, dy_righttype, dy_righttype_after + ) DI.value_and_pullback(f!, y, prep, backend, x, ty, contexts...) # warm up return prep end @@ -27,29 +39,21 @@ function DI.value_and_pullback( contexts::Vararg{Context,C}, ) where {C} dy = only(ty) - dy_righttype = convert(tangent_type(typeof(y)), copy(dy)) - dx_righttype = zero_tangent(x) - # We want the VJP, not VJP + dx, so I'm going to zero-out `dx`. `set_to_zero!!` has the advantage - # that it will also replace any immutable components of `dx` to zero. - dx_righttype = set_to_zero!!(dx_righttype) + # Set all tangent storage to zero. + df! = set_to_zero!!(prep.df!) + dx_righttype = set_to_zero!!(prep.dx_righttype) + dy_righttype = set_to_zero!!(prep.dy_righttype) - # We want `dy` to correspond to the cotangent of `y` _after_ - # running the forwards-pass, so I'm going to take a copy, and zero-out the original. - dy_righttype_backup = copy(dy_righttype) - dy_righttype = set_to_zero!!(dy_righttype) - contexts_coduals = map(zero_fcodual ∘ unwrap, contexts) - - # Mutate a copy of `y`, so that we can run the reverse-pass later on. - y_copy = copy(y) + # Prepare cotangent to add after the forward pass. + dy_righttype_after = copyto!(prep.dy_righttype_after, dy) - # In case `f!` is a closure - df! = zero_tangent(f!) + contexts_coduals = map(zero_fcodual ∘ unwrap, contexts) - # Run the forwards-pass. + # Run the forward pass out, pb!! = prep.rrule( CoDual(f!, fdata(df!)), - CoDual(y_copy, fdata(dy_righttype)), + CoDual(prep.y_copy, fdata(dy_righttype)), CoDual(x, fdata(dx_righttype)), contexts_coduals..., ) @@ -57,16 +61,16 @@ function DI.value_and_pullback( # Verify that the output is non-differentiable. @assert primal(out) === nothing - # Set the cotangent of `y` to be equal to the requested value. - dy_righttype = increment!!(dy_righttype, dy_righttype_backup) + # Increment the desired cotangent dy. + dy_righttype = increment!!(dy_righttype, dy_righttype_after) - # Record the state of `y` before running the reverse-pass. - y = copyto!(y, y_copy) + # Record the state of y before running the reverse pass. + y = copyto!(y, prep.y_copy) - # Run the reverse-pass. + # Run the reverse pass. _, _, new_dx = pb!!(NoRData()) - return y, (tangent(fdata(dx_righttype), new_dx),) + return y, (tangent(copy(fdata(dx_righttype)), new_dx),) # TODO: remove this allocation in `value_and_pullback!` end function DI.value_and_pullback(