Stateful Hand-Written Rules #403
While I was going through The Rule Interface Round 2 of the current documentation, I wrote the following code that should be equivalent to the coded
Sanity-check: my alternative seems to match the example. My REPL showed:
Did you mean something like the |
This is close to what I had in mind, but I don't think it quite captures the problem. Here's a concrete example which involves overwriting memory, thus necessitating restoring it afterwards. Consider writing a rule for the signature Tuple{typeof(mul!), Matrix{P}, Matrix{P}, Matrix{P}} where {P<:IEEEFloat}, whose semantics are (as usual) to overwrite the first matrix argument with the result of multiplying the second and third matrices. The way to do it using `rrule!!` is:
using BenchmarkTools
using Mooncake
using Mooncake: NoRData, CoDual, zero_fcodual
using Base: IEEEFloat
using LinearAlgebra: mul!
function Mooncake.rrule!!(
    ::CoDual{typeof(mul!)}, C::CoDual{Matrix{P}}, A::CoDual{Matrix{P}}, B::CoDual{Matrix{P}}
) where {P<:IEEEFloat}
    # Make a copy of `C` and its adjoint. This is where allocations are introduced.
    C_copy = copy(C.x)
    dC_copy = copy(C.dx)
    # Run the forwards-pass.
    mul!(C.x, A.x, B.x)
    C.dx .= zero(P)
    function pb!!(::NoRData)
        # Do the computations needed to increment tangents of A and B.
        # code to increment A.dx
        # code to increment B.dx
        # Reset value of `C`.
        copy!(C.x, C_copy)
        copy!(C.dx, dC_copy)
        return NoRData(), NoRData(), NoRData(), NoRData()
    end
    return C, pb!!
end
A stateful version of this might be something along the lines of
# mutable struct which might be uninitialised. We would probably insist that this struct be
# used by all stateful rules.
mutable struct StatefulRRule{Tstate}
    state::Tstate
    StatefulRRule{Tstate}() where {Tstate} = new{Tstate}()
end
# Create a mutable struct with uninitialised state. It should be possible to add a macro, in
# the same vein as `@is_primitive`, to make this function simpler to add methods to.
function build_primitive_rrule(
    ::Type{<:Tuple{typeof(mul!), Matrix{P}, Matrix{P}, Matrix{P}}}
) where {P<:IEEEFloat}
    return StatefulRRule{Tuple{Matrix{P}, Matrix{P}}}()
end
# Note: the only difference between this signature and the one for the `rrule!!` above is the function itself.
function (rule::StatefulRRule)(
    ::CoDual{typeof(mul!)}, C::CoDual{Matrix{P}}, A::CoDual{Matrix{P}}, B::CoDual{Matrix{P}}
) where {P<:IEEEFloat}
    # If we don't already have some state allocated, allocate some. After this, the
    # remainder of the function is identical to the rrule!!.
    if !isdefined(rule, :state)
        rule.state = (copy(C.x), copy(C.dx))
    end
    # We can be sure that we have state in the rule at this point, so just make use of it.
    C_copy = rule.state[1]
    dC_copy = rule.state[2]
    copy!(C_copy, C.x)
    copy!(dC_copy, C.dx)
    # Run the forwards-pass.
    mul!(C.x, A.x, B.x)
    C.dx .= zero(P)
    # The pullback can close over the `StatefulRRule`.
    function pb!!(::NoRData)
        # Do the computations needed to increment tangents of A and B.
        # code to increment A.dx
        # code to increment B.dx
        # Reset value of `C`.
        copy!(C.x, C_copy)
        copy!(C.dx, dC_copy)
        return NoRData(), NoRData(), NoRData(), NoRData()
    end
    return C, pb!!
end
sig = Tuple{typeof(mul!), Matrix{Float64}, Matrix{Float64}, Matrix{Float64}};
stateful_rule = build_primitive_rrule(sig);
Observe that the first time that a stateful rule is called it allocates the memory needed. On subsequent visits, it will just reuse the memory. As a result, we get the following timing results:
julia> C, A, B = randn(16, 8), randn(16, 32), randn(32, 8);
julia> @btime mul!($C, $A, $B);
289.234 ns (0 allocations: 0 bytes)
julia> _C, _A, _B = zero_fcodual(C), zero_fcodual(A), zero_fcodual(B);
julia> @btime Mooncake.rrule!!(zero_fcodual(mul!), $_C, $_A, $_B)[2](NoRData());
395.835 ns (4 allocations: 2.22 KiB)
julia> @btime ($stateful_rule)(zero_fcodual(mul!), $_C, $_A, $_B)[2](NoRData());
358.533 ns (0 allocations: 0 bytes)
Note that the implementation of the
These points are just to say that some thought is required regarding the caching structure, but hopefully they don't detract from the larger point: introducing a |
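The allocate-once, reuse-afterwards behaviour described in this comment can be sketched without Mooncake at all. The following is a minimal illustration, not the proposed implementation: `LazyMulRule` is a hypothetical name, tangents are passed as plain matrices rather than `CoDual`s, and the pullback fills in the tangent increments using the standard matrix-product adjoints (dA += dC * B', dB += A' * dC), which is an assumption about what the elided "code to increment" lines would do.

```julia
using LinearAlgebra: mul!

# Hypothetical stand-in for a stateful rule. `state` is left undefined until
# the first call, at which point the buffers are allocated once.
mutable struct LazyMulRule{P}
    state::Tuple{Matrix{P},Matrix{P}}
    LazyMulRule{P}() where {P} = new{P}()
end

function (rule::LazyMulRule{P})(C::Matrix{P}, dC, A, dA, B, dB) where {P}
    # Allocate buffers on the first call only.
    if !isdefined(rule, :state)
        rule.state = (similar(C), similar(dC))
    end
    C_copy, dC_copy = rule.state
    # Save the values that the forwards-pass is about to overwrite.
    copy!(C_copy, C)
    copy!(dC_copy, dC)
    # Run the forwards-pass.
    mul!(C, A, B)
    fill!(dC, zero(P))
    function pb!!()
        # 5-argument `mul!` accumulates in place: dA += dC * B', dB += A' * dC.
        mul!(dA, dC, B', true, true)
        mul!(dB, A', dC, true, true)
        # Restore `C` and its tangent to their values on entry.
        copy!(C, C_copy)
        copy!(dC, dC_copy)
        return nothing
    end
    return C, pb!!
end
```

Note that this sketch shares the limitation of the single-slot `state` field above: calling the rule a second time before running the first pullback would overwrite the saved copy.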
Note that another benefit of permitting stateful rules would be that we could properly implement ChainRules.jl's call-back-into-AD mechanism. Currently it's not really possible to do, because our AD requires carrying some state around in order to get optimal performance, but our |
Thanks for the explanation! Please post additional tips or ideas here because I'm interested to explore this issue more after my documentation PR. |
#437 made progress on this possible. What is needed now is a systematic audit of hand-written rules, to find examples where the primal does not allocate but the We need to do this systematically if we want to ensure that this problem permanently disappears, so I propose the following:
This should all be done in a single PR, as the changes to any individual rule are going to be quite minor -- it will also enable us to figure out what good abstractions look like for this. In particular, we should definitely add some abstraction to make it quite hard to accidentally implement these stateful rules incorrectly, e.g. by failing to account for the fact that the same rule can be called multiple times on the forwards-pass. Exactly what the correct solution will look like here is a little bit unclear to me, but we almost certainly do not want rule-writers to have to manage pushing / popping stacks themselves.
edit: Here's a proposal for an abstraction that could work. The proposal: we define a generic callable struct
struct StatefulRRule{sig, Tstate}
    states::Stack{Tstate}
end
# This function implements Mooncake's rule interface.
function (rule::StatefulRRule)(args::CoDual...)
    state = length(rule.states) == rule.states.position ? nothing : rule.states[rule.states.position]
    # The user will have to define a method of `stateful_rrule!!`.
    primal_fdata_pair, pb!!, new_state = stateful_rrule!!(state, args...)
    push!(rule.states, new_state)
    function stateful_pb!!(dy)
        pop!(rule.states)
        return pb!!(dy)
    end
    return primal_fdata_pair, stateful_pb!!
end
Additionally, we define a macro which is called as follows:
@stateful_primitive context_type sig state_type
which expands to
@is_primitive context_type sig
function Mooncake.build_primitive_rrule(::Type{sig}) where {sig<:Tuple}
    return StatefulRRule{sig,state_type}(Stack{state_type}())
end
i.e. the macro declares the given signature to be a primitive via the usual mechanism, but also adds a method to `build_primitive_rrule`. A rule-writer is required to implement a method of `stateful_rrule!!`.
Most of the time, the
So a rule-writer will typically write something like the following:
@stateful_primitive context_type sig state_type
function stateful_rrule!!(state, args::CoDual...)
    # Handle state.
    if state === nothing
        new_state = <allocate stuff>
    else
        <check state is of the correct size etc, and re-allocate if it is not>
    end
    # Run rule, exactly as before, modifying `new_state`.
    return y, pb!!, new_state
end
In terms of testing, we can tweak our
Note that this design ensures that rule-writers never see any state management: they just get given whatever the current state is, are free to do with it whatever they like, and just have to return a new version of state at the end of the function call, in addition to what they usually return.
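To make the "rule-writers never see any state management" idea concrete, here is a small self-contained sketch in plain Julia. It is not the `StatefulRRule` / `Stack` design itself: `StackedRule`, `stacked_rule`, and `scale2_rule` are hypothetical names, and instead of a position-tracking stack it assumes two plain vectors, one for states owned by pending pullbacks and one for states that are free to be reused. It relies on pullbacks running in the reverse order of the forwards-pass calls.

```julia
# Hypothetical wrapper: keeps one state per in-flight forwards-pass call.
struct StackedRule{F,S}
    f::F             # user function: (state, x) -> (y, pb, new_state)
    live::Vector{S}  # states currently owned by pending pullbacks
    free::Vector{S}  # states released by completed pullbacks, ready for reuse
end

stacked_rule(f, ::Type{S}) where {S} = StackedRule{typeof(f),S}(f, S[], S[])

function (rule::StackedRule)(x)
    # Hand the user an old state if one is available, otherwise `nothing`.
    state = isempty(rule.free) ? nothing : pop!(rule.free)
    y, pb, new_state = rule.f(state, x)
    push!(rule.live, new_state)
    function stateful_pb(dy)
        # Release this call's state for reuse, then run the user's pullback.
        push!(rule.free, pop!(rule.live))
        return pb(dy)
    end
    return y, stateful_pb
end

# Example user rule: doubles `x` in place, so it must save the old value.
function scale2_rule(state, x::Vector{Float64})
    buf = (state === nothing || length(state) != length(x)) ? similar(x) : state
    copy!(buf, x)  # save the value we are about to overwrite
    x .*= 2
    function pb(dx)
        copy!(x, buf)  # restore `x` to its value on entry
        return 2 .* dx
    end
    return x, pb, buf
end
```

With this arrangement, calling the wrapped rule twice before running any pullback gives each call its own buffer, and a later forwards-pass reuses the released buffers rather than allocating new ones.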
Good stuff! For the latest proposed approach, would it be a problem if the developer manually allocates non-escaping temporary storage buffers via I personally do not directly call |
This is a great question. In short: I'm not sure. In principle I can't see a problem with this kind of thing, but we would need to give some specific examples some thought in order to be sure. I've also not yet written rules for |
At present, if an `rrule!!` needs to allocate storage for some value it overwrites on the forwards pass, it must do so each time the rule is called. This is distinct from derived rules, which have state which is preserved between calls.
There's no particular reason for things to be this way. This kind of functionality would be helpful, because it would make it possible to trade off some memory in exchange for runtime performance. Take, for example, `BLAS.gemm!(transA, transB, alpha, A, B, beta, C)` -- this rule has to allocate memory to hold whatever value `C` has upon entry, in order to restore it on the reverse-pass. Currently, it must allocate this memory each time the rule is called, but if we permit rules to be stateful we can just use the same heuristic that we use for the derived rules, and avoid de-allocating this memory -- subsequent calls to a rule would be fast.
This change could probably be made non-breaking (`rrule!!`s would remain unchanged). We would just add a function `build_primitive_rrule` or something, which returns `rrule!!` by default, but which one can over-ride to return e.g. a custom callable struct instead.
I first realised that this is probably going to be required while addressing #394. The primal is type-stable and non-allocating, but Mooncake has a great deal of allocations. Now that a couple of performance bugs have been fixed (the generated IR was not type stable due to these bugs), the removal of the remaining allocations will require solving this issue.
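The non-breaking fallback described above amounts to ordinary method dispatch on the signature type. A minimal sketch, using hypothetical stand-ins rather than Mooncake's real definitions (`default_rule` plays the role of `rrule!!`, `build_rule` the role of the proposed `build_primitive_rrule`, and `CountingRule` is an arbitrary callable struct carrying state):

```julia
# Hypothetical stand-in for the existing stateless rule function.
default_rule(args...) = :stateless

# Default: every signature gets the stateless rule, so existing rules are unaffected.
build_rule(::Type{<:Tuple}) = default_rule

# A callable struct that carries state between calls.
struct CountingRule
    calls::Base.RefValue{Int}
end
(r::CountingRule)(args...) = (r.calls[] += 1; :stateful)

# Opt in for one specific signature by adding a more specific method.
build_rule(::Type{<:Tuple{typeof(sin),Float64}}) = CountingRule(Ref(0))
```

Signatures without a dedicated method fall through to the default, which is what would keep such a change from breaking existing hand-written rules.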