
Stateful Hand-Written Rules #403

Open
willtebbutt opened this issue Dec 1, 2024 · 7 comments
Labels
enhancement (performance) Would reduce the time it takes to run some bit of the code

Comments

@willtebbutt
Member

willtebbutt commented Dec 1, 2024

At present, if an rrule!! needs to allocate storage for some value it overwrites on the forwards-pass, it must do so each time the rule is called. This is distinct from derived rules, whose state is preserved between calls.

There's no particular reason for things to be this way. This kind of functionality would be helpful, because it would make it possible to trade off some memory in exchange for runtime performance. Take, for example, BLAS.gemm!(transA, transB, alpha, A, B, beta, C) -- this rule has to allocate memory to hold whatever value C has upon entry, in order to restore it on the reverse-pass. Currently, it must allocate this memory each time the rule is called, but if we permit rules to be stateful we can just use the same heuristic that we use for the derived rules, and avoid de-allocating this memory -- subsequent calls to a rule would be fast.

This change could probably be made non-breaking (rrule!!s would remain unchanged). We would just add a function, build_primitive_rrule or similar, which returns rrule!! by default, but which one can override to return e.g. a custom callable struct instead.
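
A minimal sketch of what such a hook might look like (hypothetical: foo!, StatefulFooRule, and this build_primitive_rrule are illustrative names only):

using Mooncake

# Hypothetical primal whose rule would benefit from scratch memory.
foo!(y::Vector{Float64}) = (y .*= 2.0; y)

# Default: "building" a rule for a signature just returns the ordinary, stateless rrule!!.
build_primitive_rrule(::Type{<:Tuple}) = Mooncake.rrule!!

# Override for one particular signature: return a struct that owns a reusable buffer.
# Making it callable (i.e. implementing the rule interface) is omitted here.
struct StatefulFooRule
    buffer::Vector{Float64}
end
build_primitive_rrule(::Type{Tuple{typeof(foo!),Vector{Float64}}}) = StatefulFooRule(Float64[])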

I first realised that this is probably going to be required while addressing #394 . The primal is type-stable and non-allocating, but Mooncake performs a large number of allocations. Now that a couple of performance bugs have been fixed (the generated IR was not type-stable due to these bugs), removing the remaining allocations will require solving this issue.

@willtebbutt willtebbutt added the enhancement (performance) Would reduce the time it takes to run some bit of the code label Dec 1, 2024
@RoyCCWang
Contributor

While I was going through The Rule Interface Round 2 section of the current documentation, I wrote the following code, which should be equivalent to the rrule!! example in that section; i.e., rrule_mod!! below should behave the same as the rrule!! from the example.

using LinearAlgebra
import Mooncake as MN

T = Float64

# This is from the example, the primal function.
function eval_model(p::Tuple{T, Vector{T}}) where T <: AbstractFloat
    a, b = p[1], p[2]
    return a + sum(b)
end

# This is from the example.
function rrule!!(
    ::MN.CoDual{typeof(eval_model)},
    x::MN.CoDual{
        Tuple{T, Vector{T}},
    },
    ) where T <: AbstractFloat

    dx_fdata = x.dx
    function df_adjoint(dy::T)

        dx_fdata[2] .+= dy
        dx_1_rdata = dy
        dx_rdata = (dx_1_rdata, MN.NoRData())

        return MN.NoRData(), dx_rdata
    end

    x_p = x.x
    return MN.CoDual(x_p[1] + sum(x_p[2]), MN.NoFData()), df_adjoint
end

# My alternatives that I used while exploring the documentation.
struct ReversePassCallable{T}
    aux_state::T # mutated whenever the ReversePassCallable instance is called.
end
function (A!::ReversePassCallable)(model_output::AbstractFloat)

    output_state = A!.aux_state[end]
    for i in eachindex(output_state)
        output_state[i] += model_output
    end
    return MN.NoRData(), (model_output, MN.NoRData())
end

function rrule_mod!!(
    ::MN.CoDual{typeof(eval_model)},
    state::MN.CoDual{
        Tuple{T, Vector{T}},
    },
    ) where T <: AbstractFloat

    run_reverse_pass = ReversePassCallable(get_aux(state))
    model_eval = eval_model(get_primal_inputs(state))
    return MN.CoDual(model_eval, MN.NoFData()), run_reverse_pass
end

function get_aux(A::MN.CoDual)
    return A.dx
end

function get_primal_inputs(A::MN.CoDual)
    return A.x
end

D = 2
a = T(5)
b = [T(1); T(2)]
y0 = zeros(D)

out, pb!! = rrule!!(
    MN.CoDual(
        eval_model,
        MN.NoFData(),
    ),
    MN.CoDual(
        (a, b),
        (MN.NoFData(), y0),
    ),
)

out_mod, pb_mod!! = rrule_mod!!(
    MN.CoDual(
        eval_model,
        MN.NoFData(),
    ),
    MN.CoDual(
        (a, b),
        (MN.NoFData(), y0),
    ),
)

Sanity-check: my alternative seems to match the example. My REPL showed:

julia> out
Mooncake.CoDual{Float64, Mooncake.NoFData}(8.0, Mooncake.NoFData())

julia> out_mod
Mooncake.CoDual{Float64, Mooncake.NoFData}(8.0, Mooncake.NoFData())

julia> pb!!(one(T))
(Mooncake.NoRData(), (1.0, Mooncake.NoRData()))

julia> pb_mod!!(one(T))
(Mooncake.NoRData(), (1.0, Mooncake.NoRData()))

Did you mean something like the ReversePassCallable as a first possible approach to the automatic generation of stateful rrule!! callables?

@willtebbutt
Member Author

willtebbutt commented Dec 2, 2024

This is close to what I had in mind, but I don't think it quite captures the problem, because eval_model does not appear to mutate its inputs.

Here's a concrete example which involves overwriting memory, thus necessitating that it be restored afterwards. Consider writing a rule for the signature

Tuple{typeof(mul!), Matrix{P}, Matrix{P}, Matrix{P}} where {P<:IEEEFloat}

whose semantics are (as usual) to overwrite the first matrix argument with the result of multiplying the second and third matrices.

The way to do it using rrule!! would be something like

using BenchmarkTools
using Mooncake
using Mooncake: NoRData, CoDual, zero_fcodual
using Base: IEEEFloat
using LinearAlgebra: mul!

function Mooncake.rrule!!(
    ::CoDual{typeof(mul!)}, C::CoDual{Matrix{P}}, A::CoDual{Matrix{P}}, B::CoDual{Matrix{P}}
) where {P<:IEEEFloat}

    # Make a copy of `C` and its adjoint. This is where allocations are introduced.
    C_copy = copy(C.x)
    dC_copy = copy(C.dx)

    # Run the forwards-pass.
    mul!(C.x, A.x, B.x)
    C.dx .= zero(P)

    function pb!!(::NoRData)
        # Do the computations needed to increment tangents of A and B.
        # code to increment A.dx
        # code to increment B.dx

        # Reset value of `C`.
        copy!(C.x, C_copy)
        copy!(C.dx, dC_copy)

        return NoRData(), NoRData(), NoRData(), NoRData()
    end
    return C, pb!!
end

A stateful version of this might be something along the lines of

# mutable struct which might be uninitialised. We would probably insist that this struct be
# used by all stateful rules.
mutable struct StatefulRRule{Tstate}
    state::Tstate
    StatefulRRule{Tstate}() where {Tstate} = new{Tstate}()
end

# Create a mutable struct with uninitialised state. It should be possible to add a macro, in
# the same vein as `@is_primitive`, to make this function simpler to add methods to.
function build_primitive_rrule(
    ::Type{<:Tuple{typeof(mul!), Matrix{P}, Matrix{P}, Matrix{P}}}
) where {P<:IEEEFloat}
    return StatefulRRule{Tuple{Matrix{P}, Matrix{P}}}()
end

# Note: the only difference between this signature and the one for the `rrule!!` above is the function itself.
function (rule::StatefulRRule)(
    ::CoDual{typeof(mul!)}, C::CoDual{Matrix{P}}, A::CoDual{Matrix{P}}, B::CoDual{Matrix{P}}
) where {P<:IEEEFloat}

    # If we don't already have some state allocated, allocate some. After this, the
    # remainder of the function is identical to the rrule!!.
    if !isdefined(rule, :state)
        rule.state = (copy(C.x), copy(C.dx))
    end

    # We can be sure that we have state in the rule at this point, so just make use of it.
    C_copy = rule.state[1]
    dC_copy = rule.state[2]
    copy!(rule.state[1], C.x)
    copy!(rule.state[2], C.dx)

    # Run the forwards-pass.
    mul!(C.x, A.x, B.x)
    C.dx .= zero(P)

    # The pullback can close over the `StatefulRRule`.
    function pb!!(::NoRData)
        # Do the computations needed to increment tangents of A and B.
        # code to increment A.dx
        # code to increment B.dx

        # Reset value of `C`.
        copy!(C.x, C_copy)
        copy!(C.dx, dC_copy)

        return NoRData(), NoRData(), NoRData(), NoRData()
    end
    return C, pb!!
end

sig = Tuple{typeof(mul!), Matrix{Float64}, Matrix{Float64}, Matrix{Float64}};
stateful_rule = build_primitive_rrule(sig);

Observe that the first time a stateful rule is called, it allocates the memory needed. On subsequent calls, it will just reuse that memory. As a result, we get the following timing results:

julia> C, A, B = randn(16, 8), randn(16, 32), randn(32, 8);

julia> @btime mul!($C, $A, $B);
  289.234 ns (0 allocations: 0 bytes)

julia> _C, _A, _B = zero_fcodual(C), zero_fcodual(A), zero_fcodual(B);

julia> @btime Mooncake.rrule!!(zero_fcodual(mul!), $_C, $_A, $_B)[2](NoRData());
  395.835 ns (4 allocations: 2.22 KiB)

julia> @btime ($stateful_rule)(zero_fcodual(mul!), $_C, $_A, $_B)[2](NoRData());
  358.533 ns (0 allocations: 0 bytes)

Note that the implementation of the StatefulRRule for sig above is incomplete, because it cannot handle

  1. repeated calls in the forwards pass -- we would need to keep a Mooncake.Stack of Tuples in the state, not just a single Tuple (see the sketch below), and
  2. changes in the size of the matrices passed in -- if the input dimensions change from call to call, the cache size will need to change.

These points are just to say that some thought is required regarding the caching structure, but hopefully they don't detract from the larger point: introducing a StatefulRRule type should enable us to (with some careful design) eliminate allocations in the forwards pass in exchange for increased peak memory usage.
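
A rough sketch of what handling point 1 might look like for the mul! rule above, keeping a stack of snapshots so that repeated forwards-pass calls don't clobber each other's cached copies of C (this assumes Mooncake.Stack supports push! and pop!, and it does not by itself remove the per-call allocations):

using Base: IEEEFloat
using LinearAlgebra: mul!
using Mooncake: CoDual, NoRData, Stack

# Hypothetical rule whose state is a stack of (C_copy, dC_copy) snapshots.
struct StackedMulRRule{P}
    states::Stack{Tuple{Matrix{P},Matrix{P}}}
end

function (rule::StackedMulRRule{P})(
    ::CoDual{typeof(mul!)}, C::CoDual{Matrix{P}}, A::CoDual{Matrix{P}}, B::CoDual{Matrix{P}}
) where {P<:IEEEFloat}

    # Push a snapshot of `C` so the matching reverse-pass call can restore it, even if this
    # rule runs several times before any pullback runs. (A complete implementation would
    # reuse buffers already on the stack rather than copying into fresh ones.)
    push!(rule.states, (copy(C.x), copy(C.dx)))

    # Run the forwards-pass.
    mul!(C.x, A.x, B.x)
    C.dx .= zero(P)

    function pb!!(::NoRData)
        # code to increment A.dx and B.dx goes here.

        # Reset value of `C` from the most recent snapshot.
        C_copy, dC_copy = pop!(rule.states)
        copy!(C.x, C_copy)
        copy!(C.dx, dC_copy)
        return NoRData(), NoRData(), NoRData(), NoRData()
    end
    return C, pb!!
end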

@willtebbutt
Member Author

Note that another benefit of permitting stateful rules is that we could properly implement ChainRules.jl's call-back-into-AD mechanism. Currently this isn't really possible, because our AD requires carrying some state around in order to get optimal performance, but our rrule!!s do not admit state.

@RoyCCWang
Contributor

Thanks for the explanation! Please post additional tips or ideas here, as I'm interested in exploring this issue further after my documentation PR.

@willtebbutt
Member Author

willtebbutt commented Jan 27, 2025

#437 made it possible to make progress on this. What is needed now is a systematic audit of hand-written rules, to find examples where the primal does not allocate but the rrule!! does, and to convert them into stateful rules.

We need to do this systematically if we want to ensure that this problem permanently disappears, so I propose the following:

  1. add a test to test_rrule_performance (reproduced below) which fails if the primal does not allocate but the rule does, and permit this to be restricted to only rules for which is_primitive is true,
  2. use this test to identify all hand-written rules which need to be made stateful, and
  3. re-write the rules.

function test_rrule_performance(
    performance_checks_flag::Symbol, rule::R, f_f̄::F, x_x̄::Vararg{Any,N}
) where {R,F,N}

    # Verify that a valid performance flag has been passed.
    valid_flags = (:none, :stability, :allocs, :stability_and_allocs)
    if !in(performance_checks_flag, valid_flags)
        throw(
            ArgumentError(
                "performance_checks=$performance_checks_flag. Must be one of $valid_flags"
            ),
        )
    end
    performance_checks_flag == :none && return nothing

    if performance_checks_flag in (:stability, :stability_and_allocs)

        # Test primal stability.
        test_opt(Shim(), primal(f_f̄), map(_typeof ∘ primal, x_x̄))

        # Test forwards-pass stability.
        test_opt(Shim(), rule, (_typeof(to_fwds(f_f̄)), map(_typeof ∘ to_fwds, x_x̄)...))

        # Test reverse-pass stability.
        y_ȳ, pb!! = rule(to_fwds(f_f̄), map(to_fwds, _deepcopy(x_x̄))...)
        rvs_data = Mooncake.rdata(zero_tangent(primal(y_ȳ), tangent(y_ȳ)))
        test_opt(Shim(), pb!!, (_typeof(rvs_data),))
    end

    if performance_checks_flag in (:allocs, :stability_and_allocs)

        f = primal(f_f̄)
        x = map(primal, x_x̄)

        # Test allocations in primal.
        f(x...)
        @test (@allocations f(x...)) == 0

        # Test allocations in round-trip.
        f_f̄_fwds = to_fwds(f_f̄)
        x_x̄_fwds = map(to_fwds, x_x̄)
        __forwards_and_backwards(rule, f_f̄_fwds, x_x̄_fwds...)
        @test (@allocations __forwards_and_backwards(rule, f_f̄_fwds, x_x̄_fwds...)) == 0
    end
end

This should all be done in a single PR, as the changes to any individual rule are going to be quite minor -- it will also enable us to figure out what good abstractions look like for this. In particular, we should definitely add some abstraction that makes it hard to accidentally implement these stateful rules incorrectly, e.g. by failing to account for the fact that the same rule can be called multiple times on the forwards-pass. Exactly what the correct solution looks like is a little unclear to me, but we almost certainly do not want rule-writers to have to manage pushing / popping stacks themselves.

edit:

Here's a proposal for an abstraction that could work.
The goal here is to prevent rule-writers from ever having to manage stacks themselves, as doing so opens them up to writing rules which get state-management wrong, and lead to hard-to-debug correctness problems.
It's worth being quite pedantic to ensure that we minimise the chance of this happening by-design, as testing for it in any thorough way is going to be quite tricky.

The proposal: we define a generic callable struct

struct StatefulRRule{sig, Tstate}
    states::Stack{Tstate}
end

# This function implements Mooncake's rule interface.
function (rule::StatefulRRule)(args::CoDual...)
    # Pass in previously-allocated state if any is available, otherwise `nothing`.
    state = length(rule.states) == rule.states.position ? nothing : rule.states[rule.states.position]
    primal_fdata_pair, pb!!, new_state = stateful_rrule!!(state, args...) # user will have to define a method of this function
    push!(rule.states, new_state)
    function stateful_pb!!(dy)
        pop!(rule.states)
        return pb!!(dy)
    end
    return primal_fdata_pair, stateful_pb!!
end

Additionally, we define a macro which is called as follows:

@stateful_primitive context_type sig state_type

which expands to

@is_primitive context_type sig
function Mooncake.build_primitive_rrule(::Type{<:sig})
    return StatefulRRule{sig,state_type}(Stack{state_type}())
end

i.e. the macro declares the given signature to be a primitive via the usual mechanism, but also adds a method to build_primitive_rrule which constructs an instance of StatefulRRule.
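
A hypothetical sketch of how such a macro might be written (none of this exists yet; it simply splices the three macro arguments into the expansion above):

macro stateful_primitive(context_type, sig, state_type)
    return esc(
        quote
            # Declare the signature to be a primitive via the existing mechanism.
            Mooncake.@is_primitive $context_type $sig
            # Wire up a StatefulRRule, whose stack holds values of the given state type.
            function Mooncake.build_primitive_rrule(::Type{<:$sig})
                return StatefulRRule{$sig,$state_type}(Stack{$state_type}())
            end
        end,
    )
end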

A rule-writer is required to implement a method of stateful_rrule!!.
This function has similar semantics to rrule!! but

  1. receives an additional argument at the start, which is the state that the StatefulRRule thinks is likely to be useful. If this is the first time we're calling the stateful rule, then nothing will be passed in.
  2. the user must return three arguments. The first two are the same as what rrule!! returns, while the third is the updated state.

Most of the time, the new_state returned from stateful_rrule!! and the state passed in will be ===; i.e. the rule should typically only modify the contents of the state.
The exceptions are the first time the rule is called, when the state passed in will be nothing and the rule author has to allocate the memory, and cases where control flow changes the order in which things are called, such that memory has to be re-allocated.

So a rule-writer will typically write something like the following:

@stateful_primitive context_type sig state_type
function stateful_rrule!!(state, args::CoDual...)

    # Handle state.
    if state === nothing
        new_state = <allocate stuff>
    else
        new_state = <check `state` is of the correct size etc, and re-allocate if it is not>
    end

    # Run rule, exactly as before, modifying `new_state`.

    return y, pb!!, new_state
end
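
For concreteness, the mul! rule from earlier in this thread might look roughly like the following under this proposed interface (hypothetical, since stateful_rrule!! does not exist; the tangent-increment code is elided as before):

using Base: IEEEFloat
using LinearAlgebra: mul!
using Mooncake: CoDual, NoRData

function stateful_rrule!!(
    state, ::CoDual{typeof(mul!)}, C::CoDual{Matrix{P}}, A::CoDual{Matrix{P}}, B::CoDual{Matrix{P}}
) where {P<:IEEEFloat}

    # Handle state: allocate buffers on the first call, or whenever the size of `C` changes.
    if state === nothing || size(state[1]) != size(C.x)
        new_state = (copy(C.x), copy(C.dx))
    else
        new_state = state  # reuse the existing buffers, modifying only their contents.
        copy!(new_state[1], C.x)
        copy!(new_state[2], C.dx)
    end

    # Run the forwards-pass, exactly as in the rrule!! version.
    mul!(C.x, A.x, B.x)
    C.dx .= zero(P)

    function pb!!(::NoRData)
        # code to increment A.dx and B.dx goes here.

        # Reset value of `C` from the cached buffers.
        copy!(C.x, new_state[1])
        copy!(C.dx, new_state[2])
        return NoRData(), NoRData(), NoRData(), NoRData()
    end
    return C, pb!!, new_state
end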

In terms of testing, we can tweak our test_rule function slightly to ensure that this interface has been implemented correctly (i.e. that a method of stateful_rrule!! has been implemented, that it returns things of the correct type, and that it runs when state is either nothing or an instance of state_type), and then apply the rest of the test suite to check correctness and performance in the usual manner.

Note that this design ensures that rule-writers never see any state management: they are given whatever the current state is, are free to do with it whatever they like, and simply have to return a new version of the state at the end of the function call, in addition to what they usually return.

@RoyCCWang
Contributor

RoyCCWang commented Feb 3, 2025

Good stuff!

For the latest proposed approach, would it be a problem if the developer manually allocates non-escaping temporary storage buffers via malloc and free in the implementation of the stateful function to be differentiated?

I personally do not directly call malloc inside mutating functions, but I wonder if it is safe to call a mutating function from another package when that function uses routines from PtrArrays.jl or Bumper.jl that call malloc under the hood.

@willtebbutt
Member Author

This is a great question. In short: I'm not sure. In principle I can't see a problem with this kind of thing, but we would need to give some specific examples some thought in order to be sure.

I've also not yet written rules for malloc and free, but I'm fairly sure that they wouldn't pose too much trouble.
