Skip to content

Commit

Permalink
Proper Circular Reference Handling (#416)
Browse files Browse the repository at this point in the history
* Progress

* Bump Julia version requirement

* Some fixes

* Fix up errors

* Get tests passing

* Cache -> MaybeCache

* Fix more tests

* Fix GPU problems

* Fix increment correctness problems on LTS

* Enable more tests

* Formatting

* Fix missing implementations

* Fix alloc check

* Fix tuple increment inference

* More performance fixes

* Fix typo

* Bump patch version

* Simplify build_tangent implementation

* Simplify build_tangent implementation

* Remove redundant LoC
  • Loading branch information
willtebbutt authored Feb 4, 2025
1 parent 65ef3c0 commit faf19cb
Show file tree
Hide file tree
Showing 14 changed files with 700 additions and 358 deletions.
4 changes: 2 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "Mooncake"
uuid = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6"
authors = ["Will Tebbutt, Hong Ge, and contributors"]
version = "0.4.80"
version = "0.4.81"

[deps]
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
Expand Down Expand Up @@ -66,7 +66,7 @@ Setfield = "1"
SpecialFunctions = "2"
StableRNGs = "1"
Test = "1"
julia = "~1.10, 1.11.2"
julia = "~1.10, 1.11.3"

[extras]
AllocCheck = "9b6a8646-10ed-4001-bbdc-1d2f46dfbb1a"
Expand Down
4 changes: 3 additions & 1 deletion ext/MooncakeAllocCheckExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@ module MooncakeAllocCheckExt
using AllocCheck, Mooncake
import Mooncake.TestUtils: check_allocs, Shim

@check_allocs check_allocs(::Shim, f::F, x...) where {F} = f(x...)
@check_allocs check_allocs(::Shim, f::F, x) where {F} = f(x)
@check_allocs check_allocs(::Shim, f::F, x, y) where {F} = f(x, y)
@check_allocs check_allocs(::Shim, f::F, x, y, z) where {F} = f(x, y, z)

end
95 changes: 68 additions & 27 deletions ext/MooncakeCUDAExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,49 +11,90 @@ import Mooncake:
@is_primitive,
tangent_type,
tangent,
zero_tangent,
randn_tangent,
increment!!,
_set_to_zero!!,
_add_to_primal,
_diff,
_dot,
_scale,
zero_tangent_internal,
randn_tangent_internal,
increment_internal!!,
set_to_zero_internal!!,
_add_to_primal_internal,
_diff_internal,
_dot_internal,
_scale_internal,
TestUtils,
CoDual,
NoPullback,
to_cr_tangent,
increment_and_get_rdata!
increment_and_get_rdata!,
MaybeCache,
IncCache

import Mooncake.TestUtils: populate_address_map!, AddressMap, __increment_should_allocate
import Mooncake.TestUtils:
populate_address_map_internal, AddressMap, __increment_should_allocate

const CuFloatArray = CuArray{<:IEEEFloat}

# Tell Mooncake.jl how to handle CuArrays.

Mooncake.@tt_effects tangent_type(::Type{P}) where {P<:CuArray{<:IEEEFloat}} = P
zero_tangent(x::CuArray{<:IEEEFloat}) = zero(x)
function randn_tangent(rng::AbstractRNG, x::CuArray{Float32})
return cu(randn(rng, Float32, size(x)...))
Mooncake.@tt_effects tangent_type(::Type{P}) where {P<:CuFloatArray} = P
function zero_tangent_internal(x::CuFloatArray, stackdict::Any)
haskey(stackdict, x) && return stackdict[x]::tangent_type(typeof(x))
t = zero(x)
stackdict[x] = t
return t
end
function randn_tangent_internal(rng::AbstractRNG, x::CuFloatArray, stackdict::Any)
haskey(stackdict, x) && return stackdict[x]::tangent_type(typeof(x))
t = cu(randn(rng, Float32, size(x)...))
stackdict[x] = t
return t
end
function TestUtils.has_equal_data_internal(
x::P, y::P, equal_undefs::Bool, d::Dict{Tuple{UInt,UInt},Bool}
) where {P<:CuArray{<:IEEEFloat}}
) where {P<:CuFloatArray}
return isapprox(x, y)
end
increment!!(x::P, y::P) where {P<:CuArray{<:IEEEFloat}} = x .+= y
__increment_should_allocate(::Type{<:CuArray{<:IEEEFloat}}) = true
_set_to_zero!!(::Mooncake.IncCache, x::CuArray{<:IEEEFloat}) = x .= 0
_add_to_primal(x::P, y::P, ::Bool) where {P<:CuArray{<:IEEEFloat}} = x + y
_diff(x::P, y::P) where {P<:CuArray{<:IEEEFloat}} = x - y
_dot(x::P, y::P) where {P<:CuArray{<:IEEEFloat}} = Float64(dot(x, y))
_scale(x::Float64, y::P) where {T<:IEEEFloat,P<:CuArray{T}} = T(x) * y
function populate_address_map!(m::AddressMap, p::CuArray, t::CuArray)
function increment_internal!!(c::IncCache, x::P, y::P) where {P<:CuFloatArray}
(x === y || haskey(c, x)) && return x
c[x] = true
x .+= y
return x
end
__increment_should_allocate(::Type{<:CuFloatArray}) = true
set_to_zero_internal!!(::Mooncake.IncCache, x::CuFloatArray) = x .= 0
function _add_to_primal_internal(
c::MaybeCache, x::P, y::P, unsafe::Bool
) where {P<:CuFloatArray}
key = (x, y, unsafe)
haskey(c, key) && return c[key]::P
x′ = x + y
c[(x, y, unsafe)] = x′
return x′
end
function _diff_internal(c::MaybeCache, x::P, y::P) where {P<:CuFloatArray}
key = (x, y)
haskey(c, key) && return c[key]::tangent_type(P)
t = x - y
c[key] = t
return t
end
function _dot_internal(c::MaybeCache, x::P, y::P) where {P<:CuFloatArray}
key = (x, y)
haskey(c, key) && return c[key]::Float64
return Float64(dot(x, y))
end
function _scale_internal(c::MaybeCache, x::Float64, y::P) where {T<:IEEEFloat,P<:CuArray{T}}
haskey(c, y) && return c[y]::P
t′ = T(x) * y
c[y] = t′
return t′
end
function populate_address_map_internal(m::AddressMap, p::CuArray, t::CuArray)
k = pointer_from_objref(p)
v = pointer_from_objref(t)
haskey(m, k) && (@assert m[k] == v)
m[k] = v
return m
end
function Mooncake._verify_fdata_value(p::CuArray, f::CuArray)
function Mooncake.__verify_fdata_value(::IdDict{Any,Nothing}, p::CuArray, f::CuArray)
if size(p) != size(f)
throw(InvalidFDataException("p has size $(size(p)) but f has size $(size(f))"))
end
Expand All @@ -62,8 +103,8 @@ end
tangent_type(::Type{P}, ::Type{NoRData}) where {P<:CuArray} = P
tangent(p::CuArray, ::NoRData) = p

to_cr_tangent(x::CuArray{<:IEEEFloat}) = x
function increment_and_get_rdata!(f::T, ::NoRData, t::T) where {T<:CuArray{<:IEEEFloat}}
to_cr_tangent(x::CuFloatArray) = x
function increment_and_get_rdata!(f::T, ::NoRData, t::T) where {T<:CuFloatArray}
f .+= t
return NoRData()
end
Expand All @@ -73,7 +114,7 @@ end
@is_primitive(MinimalCtx, Tuple{Type{<:CuArray},UndefInitializer,Vararg{Int,N}} where {N},)
function rrule!!(
p::CoDual{Type{P}}, init::CoDual{UndefInitializer}, dims::CoDual{Int}...
) where {P<:CuArray{<:Base.IEEEFloat}}
) where {P<:CuFloatArray}
_dims = map(primal, dims)
return CoDual(P(undef, _dims), P(undef, _dims)), NoPullback(p, init, dims...)
end
Expand Down
39 changes: 26 additions & 13 deletions src/fwds_rvs_data.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ struct NoFData end

Base.copy(::NoFData) = NoFData()

increment!!(::NoFData, ::NoFData) = NoFData()
increment_internal!!(::IncCache, ::NoFData, ::NoFData) = NoFData()

"""
FData(data::NamedTuple)
Expand All @@ -26,7 +26,9 @@ _copy(x::P) where {P<:FData} = P(_copy(x.data))

fields_type(::Type{FData{T}}) where {T<:NamedTuple} = T

increment!!(x::F, y::F) where {F<:FData} = F(tuple_map(increment!!, x.data, y.data))
function increment_internal!!(c::IncCache, x::F, y::F) where {F<:FData}
return F(tuple_map((a, b) -> increment_internal!!(c, a, b), x.data, y.data))
end

"""
fdata_type(T)
Expand Down Expand Up @@ -301,16 +303,18 @@ condition, just that they rule out a variety of common problems.
Put differently, we cannot prove that `f` is valid fdata, only that it is not obviously
invalid.
"""
function verify_fdata_value(p, f)::Nothing
verify_fdata_value(p, f)::Nothing = _verify_fdata_value(IdDict{Any,Nothing}(), p, f)

function _verify_fdata_value(c::IdDict{Any,Nothing}, p, f)::Nothing
verify_fdata_type(_typeof(p), typeof(f))
return _verify_fdata_value(p, f)
return __verify_fdata_value(c, p, f)
end

_verify_fdata_value(::IEEEFloat, ::NoFData) = nothing
__verify_fdata_value(::IdDict{Any,Nothing}, ::IEEEFloat, ::NoFData) = nothing

_verify_fdata_value(::Ptr, ::Ptr) = nothing
__verify_fdata_value(::IdDict{Any,Nothing}, ::Ptr, ::Ptr) = nothing

function _verify_fdata_value(p::Array, f::Array)
function __verify_fdata_value(c::IdDict{Any,Nothing}, p::Array, f::Array)
if size(p) != size(f)
throw(InvalidFDataException("p has size $(size(p)) but f has size $(size(f))"))
end
Expand All @@ -323,8 +327,11 @@ function _verify_fdata_value(p::Array, f::Array)
# correct separately.
for n in eachindex(p)
if isassigned(p, n)
_p = p[n]
ismutable(_p) && haskey(c, _p) && continue
ismutable(_p) && !haskey(c, _p) && setindex!(c, nothing, _p)
t = f[n]
verify_fdata_value(p[n], fdata(t))
_verify_fdata_value(c, p[n], fdata(t))
verify_rdata_value(p[n], rdata(t))
end
end
Expand All @@ -338,7 +345,7 @@ _get_fdata_field(f::Tuple, name) = getfield(f, name)
_get_fdata_field(f::FData, name) = val(getfield(f.data, name))
_get_fdata_field(f::MutableTangent, name) = fdata(val(getfield(f.fields, name)))

function _verify_fdata_value(p, f)
function __verify_fdata_value(c::IdDict{Any,Nothing}, p, f)

# If f is a NoFData then there are no checks needed, because we have already verified
# that NoFData is the correct type for fdata for p, and NoFData is a singleton type.
Expand All @@ -355,8 +362,10 @@ function _verify_fdata_value(p, f)
for name in fieldnames(P)
if isdefined(p, name)
_p = getfield(p, name)
ismutable(_p) && haskey(c, _p) && continue
ismutable(_p) && !haskey(c, _p) && setindex!(c, nothing, _p)
t = _get_fdata_field(f, name)
verify_fdata_value(_p, t)
_verify_fdata_value(c, _p, t)
if f isa MutableTangent
verify_rdata_value(_p, rdata(val(getfield(f.fields, name))))
end
Expand All @@ -375,7 +384,7 @@ struct NoRData end

Base.copy(::NoRData) = NoRData()

@inline increment!!(::NoRData, ::NoRData) = NoRData()
@inline increment_internal!!(::IncCache, ::NoRData, ::NoRData) = NoRData()

@inline increment_field!!(::NoRData, y, ::Val) = NoRData()

Expand All @@ -387,7 +396,9 @@ _copy(x::P) where {P<:RData} = P(_copy(x.data))

fields_type(::Type{RData{T}}) where {T<:NamedTuple} = T

@inline increment!!(x::RData{T}, y::RData{T}) where {T} = RData(increment!!(x.data, y.data))
@inline function increment_internal!!(c::IncCache, x::RData{T}, y::RData{T}) where {T}
return RData(increment_internal!!(c, x.data, y.data))
end

@inline function increment_field!!(x::RData{T}, y, ::Val{f}) where {T,f}
y isa NoRData && return x
Expand Down Expand Up @@ -950,7 +961,9 @@ end
Increment the rdata component of tangent `t` by `r`, and return the updated tangent.
Useful for implementation getfield-like rules for mutable structs, pointers, dicts, etc.
"""
increment_rdata!!(t::T, r) where {T} = tangent(fdata(t), increment!!(rdata(t), r))::T
function increment_rdata!!(t::T, r) where {T}
return tangent(fdata(t), increment_internal!!(NoCache(), rdata(t), r))::T
end

"""
zero_tangent(primal, fdata)
Expand Down
41 changes: 28 additions & 13 deletions src/rrules/array_legacy.jl
Original file line number Diff line number Diff line change
Expand Up @@ -18,39 +18,54 @@ function randn_tangent_internal(
return _map_if_assigned!(x -> randn_tangent_internal(rng, x, stackdict), dx, x)
end

function increment!!(x::T, y::T) where {P,N,T<:Array{P,N}}
return x === y ? x : _map_if_assigned!(increment!!, x, x, y)
function increment_internal!!(c::IncCache, x::T, y::T) where {P,N,T<:Array{P,N}}
(haskey(c, x) || x === y) && return x
c[x] = true
return _map_if_assigned!((x, y) -> increment_internal!!(c, x, y), x, x, y)
end

function _set_to_zero!!(c::IncCache, x::Array)
function set_to_zero_internal!!(c::IncCache, x::Array)
haskey(c, x) && return x
c[x] = false
return _map_if_assigned!(Base.Fix1(_set_to_zero!!, c), x, x)
return _map_if_assigned!(Base.Fix1(set_to_zero_internal!!, c), x, x)
end

function _scale(a::Float64, t::Array{T,N}) where {T,N}
function _scale_internal(c::MaybeCache, a::Float64, t::Array{T,N}) where {T,N}
haskey(c, t) && return c[t]::Array{T,N}
t′ = Array{T,N}(undef, size(t)...)
return _map_if_assigned!(Base.Fix1(_scale, a), t′, t)
c[t] = t′
return _map_if_assigned!(t -> _scale_internal(c, a, t), t′, t)
end

function _dot(t::T, s::T) where {T<:Array}
isbitstype(T) && return sum(_map(_dot, t, s))
function _dot_internal(c::MaybeCache, t::T, s::T) where {T<:Array}
key = (t, s)
haskey(c, key) && return c[key]::Float64
c[key] = 0.0
isbitstype(T) && return sum(_map((t, s) -> _dot_internal(c, t, s), t, s))
return sum(
_map(eachindex(t)) do n
(isassigned(t, n) && isassigned(s, n)) ? _dot(t[n], s[n]) : 0.0
(isassigned(t, n) && isassigned(s, n)) ? _dot_internal(c, t[n], s[n]) : 0.0
end;
init=0.0,
)
end

function _add_to_primal(x::Array{P,N}, t::Array{<:Any,N}, unsafe::Bool) where {P,N}
function _add_to_primal_internal(
c::MaybeCache, x::Array{P,N}, t::Array{<:Any,N}, unsafe::Bool
) where {P,N}
key = (x, t, unsafe)
haskey(c, key) && return c[key]::Array{P,N}
x′ = Array{P,N}(undef, size(x)...)
return _map_if_assigned!((x, t) -> _add_to_primal(x, t, unsafe), x′, x, t)
c[key] = x′
return _map_if_assigned!((x, t) -> _add_to_primal_internal(c, x, t, unsafe), x′, x, t)
end

function _diff(p::P, q::P) where {V,N,P<:Array{V,N}}
function _diff_internal(c::MaybeCache, p::P, q::P) where {V,N,P<:Array{V,N}}
key = (p, q)
haskey(c, key) && return c[key]::tangent_type(P)
t = Array{tangent_type(V),N}(undef, size(p))
return _map_if_assigned!(_diff, t, p, q)
c[key] = t
return _map_if_assigned!((p, q) -> _diff_internal(c, p, q), t, p, q)
end

@zero_adjoint MinimalCtx Tuple{Type{<:Array{T,N}},typeof(undef),Vararg} where {T,N}
Expand Down
Loading

2 comments on commit faf19cb

@willtebbutt
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Registration pull request created: JuliaRegistries/General/124314

Tip: Release Notes

Did you know you can add release notes too? Just add markdown formatted text underneath the comment after the text
"Release notes:" and it will be added to the registry PR, and if TagBot is installed it will also be added to the
release that TagBot creates. i.e.

@JuliaRegistrator register

Release notes:

## Breaking changes

- blah

To add them here just re-invoke and the PR will be updated.

Tagging

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v0.4.81 -m "<description of version>" faf19cb8ecfa1886a0bcf573794b0cedc363f424
git push origin v0.4.81

Please sign in to comment.