Patch Flux._isleaf for abstract arrays with bitstype elements #2436

Merged · 3 commits · May 9, 2024

Conversation

@jondeuce (Contributor) commented May 6, 2024

Fixes #2432:

julia> x = [1.0];
julia> Flux._isleaf(x') # should be false
true # false with this PR

julia> m = (; a = x, b = x'); # tied weights
julia> mcopy = fmap(copy, m; exclude = Flux._isleaf); # copy should preserve tie
julia> mcopy.a === mcopy.b' # should be true
false # true with this PR

On master we have Flux._isleaf(x) = _isbitsarray(x) || Functors.isleaf(x), but _isbitsarray(x) returns true for any AbstractArray{T} with isbitstype(T) == true, so _isbitsarray([1.0]') == true and therefore Flux._isleaf([1.0]') == true. As described in the referenced issue, this breaks parameter sharing between a set of weights and their transpose when a model is moved to the GPU.
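For reference, the relevant definitions on master are roughly of this shape (a paraphrased sketch, not the exact source):

# Paraphrased sketch of the definitions on Flux master; the exact source may differ.
_isbitsarray(::AbstractArray{<:Number}) = true
_isbitsarray(::AbstractArray{T}) where {T} = isbitstype(T)
_isbitsarray(x) = false
_isleaf(x) = _isbitsarray(x) || Functors.isleaf(x)

# Since Adjoint{Float64, Vector{Float64}} <: AbstractArray{Float64}, the first
# method applies to [1.0]' and the wrapper is (incorrectly) treated as a leaf.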

The fundamental issue is that, AFAICT, there is no good way to extend Functors.isleaf outside of Functors.jl for abstract types whose children may themselves be of the same abstract type. For example, Transpose <: AbstractArray wraps a parent::AbstractArray field, so Functors.jl must overload Functors.functor(::Transpose); otherwise the wrapper would not be recursed into. But of course this can't be done similarly outside of Functors.jl without type piracy (hence Flux._isleaf).
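The overload Functors.jl needs for this is roughly of the following shape (an illustrative sketch; the actual definition in Functors.jl may differ in detail):

# Illustrative sketch of the wrapper overload in Functors.jl; the real definition may differ.
Functors.functor(::Type{<:Transpose}, x) = (parent = parent(x),), y -> transpose(y.parent)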

So in order to:

  1. Avoid copying all the overloaded Functors.functor(::AbstractArray) methods defined in Functors.jl into Flux, and
  2. Maintain the desired behaviour of treating AbstractArrays with bitstype elements as leaves,

I've removed _isbitsarray in favour of defining _isleaf methods directly, and special-cased Flux._isleaf(::Union{Transpose, Adjoint, PermutedDimsArray}) = false to match Functors.jl. The other option is to do this in Functors.jl directly, but I'm not sure if treating all bitstype arrays as leaves is desirable in general (also it's probably breaking?).
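Concretely, the resulting methods are along these lines (a rough sketch of the approach, not necessarily the exact merged diff):

# Rough sketch of the approach described above; the merged diff may differ in detail.
_isleaf(::Union{Transpose, Adjoint, PermutedDimsArray}) = false  # recurse into wrappers, matching Functors.jl
_isleaf(x::AbstractArray{T}) where {T} = isbitstype(T) || Functors.isleaf(x)
_isleaf(x) = Functors.isleaf(x)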

@ToucheSir (Member)

Thanks for the PR! I'm a little late responding to the main issue so I'll do so here. It looks like there are still GPU-related tests failing and it's not immediately clear why, so more tweaking may be required.

Stepping back a bit, you were right to point out in #2432 that the call stack and division of responsibilities between the various libraries here are not super clear. Let me lay out what I think a more ideal design could be, and let's see how much of it can inform this PR.

One possibility to bridge the disparate views of Flux, Functors and Adapt (which wasn't mentioned in #2432 but is a very important third player here) is to move some logic to the latter. My proposal would be to share the cache fmap uses to detect shared parameters with Flux's Adapt adapters. The motivation here is that Adapt has much better coverage of array wrapper types than Functors, including how to recurse into them. By continuing to consider wrapped arrays as "XPU-able" in Flux, we can decouple this part of the conversion logic and delegate it to Adapt instead of having to re-implement it.
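To illustrate why Adapt is well placed for this, here is a toy example (ToyAdaptor below is hypothetical and purely for illustration; Flux's real device adaptors work analogously):

using Adapt, LinearAlgebra

# Hypothetical adaptor: "moving to the device" is just a copy here.
struct ToyAdaptor end
Adapt.adapt_storage(::ToyAdaptor, x::AbstractArray) = copy(x)

A = transpose(rand(2, 3))
B = adapt(ToyAdaptor(), A)  # Adapt unwraps the Transpose, adapts the parent, and re-wraps it
B isa Transpose             # true: the wrapper structure is preserved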

@jondeuce (Contributor, Author) commented May 8, 2024

> It looks like there are still GPU-related tests failing and it's not immediately clear why, so more tweaking may be required.

I think the only failure is

@test gradient(x -> sum(cpu(x)), ca')[1] isa CuArray

which fails because previously gradient(x -> sum(cpu(x)), ca')[1] returned an all-ones CuMatrix{Float32}, but now returns an all-ones Adjoint{Float32, CuMatrix{Float32}}, which is a slight improvement. I'll update the test and see if anything else fails.
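For example, the test could be relaxed along these lines (a sketch; the actual updated test may differ):

# Sketch only; the actual test change in the PR may differ.
@test gradient(x -> sum(cpu(x)), ca')[1] isa Union{CuArray, Adjoint{Float32, <:CuArray}}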

> One possibility to bridge the disparate views of Flux, Functors and Adapt (which wasn't mentioned in #2432 but is a very important third player here) is to move some logic to the latter. My proposal would be to share the cache fmap uses to detect shared parameters with Flux's Adapt adapters. The motivation here is that Adapt has much better coverage of array wrapper types than Functors, including how to recurse into them. By continuing to consider wrapped arrays as "XPU-able" in Flux, we can decouple this part of the conversion logic and delegate it to Adapt instead of having to re-implement it.

All of this sounds fantastic, and of course much more robust than this PR, which is really just meant as a stopgap to patch adjoint/transpose/etc. until a more complete fix is implemented.

codecov bot commented May 8, 2024

Codecov Report

All modified and coverable lines are covered by tests ✅

Project coverage is 74.10%. Comparing base (c442f0c) to head (37bec19).

Additional details and impacted files
@@             Coverage Diff             @@
##           master    #2436       +/-   ##
===========================================
+ Coverage   46.37%   74.10%   +27.72%     
===========================================
  Files          32       32               
  Lines        1876     1923       +47     
===========================================
+ Hits          870     1425      +555     
+ Misses       1006      498      -508     


@CarloLucibello merged commit 26c9acf into FluxML:master on May 9, 2024.
24 of 29 checks passed.
Successfully merging this pull request may close these issues.

Dense layers with shared parameters