Zygote compat is lacking #232

Open
torfjelde opened this issue Jan 15, 2023 · 11 comments

@torfjelde

Zygote doesn't interact too nicely with LazyArrays.jl it seems, e.g.:

julia> f(x) = sum(BroadcastArray(exp, x))
f (generic function with 1 method)

julia> Zygote.gradient(f, randn(10))
ERROR: type Array has no field f
Stacktrace:
  [1] adjoint
    @ ~/.julia/packages/Zygote/AS0Go/src/lib/lib.jl:229 [inlined]
  [2] _pullback
    @ ~/.julia/packages/ZygoteRules/AIbCs/src/adjoint.jl:65 [inlined]
  [3] _pullback
    @ ~/.julia/packages/LazyArrays/NYra8/src/lazyapplying.jl:50 [inlined]
  [4] _pullback(::Zygote.Context{false}, ::typeof(LazyArrays.call), ::ArrayLayouts.DenseColumnMajor, ::Vector{Float64})
    @ Zygote ~/.julia/packages/Zygote/AS0Go/src/compiler/interface2.jl:0
  [5] _pullback
    @ ~/.julia/packages/LazyArrays/NYra8/src/lazyapplying.jl:52 [inlined]
  [6] _pullback
    @ ~/.julia/packages/LazyArrays/NYra8/src/lazybroadcasting.jl:82 [inlined]
  [7] _pullback
    @ ~/.julia/packages/LazyArrays/NYra8/src/lazybroadcasting.jl:57 [inlined]
  [8] _pullback(::Zygote.Context{false}, ::Type{BroadcastArray}, ::typeof(exp), ::Vector{Float64})
    @ Zygote ~/.julia/packages/Zygote/AS0Go/src/compiler/interface2.jl:0
  [9] _pullback
    @ ./REPL[48]:1 [inlined]
 [10] _pullback(ctx::Zygote.Context{false}, f::typeof(f), args::Vector{Float64})
    @ Zygote ~/.julia/packages/Zygote/AS0Go/src/compiler/interface2.jl:0
 [11] pullback(f::Function, cx::Zygote.Context{false}, args::Vector{Float64})
    @ Zygote ~/.julia/packages/Zygote/AS0Go/src/compiler/interface.jl:44
 [12] pullback
    @ ~/.julia/packages/Zygote/AS0Go/src/compiler/interface.jl:42 [inlined]
 [13] gradient(f::Function, args::Vector{Float64})
    @ Zygote ~/.julia/packages/Zygote/AS0Go/src/compiler/interface.jl:96
 [14] top-level scope
    @ REPL[50]:1

julia> g(x) = sum(LazyArray(@~ exp.(x)))
g (generic function with 1 method)

julia> Zygote.gradient(g, randn(10))
ERROR: MethodError: no method matching LazyArray(::Vector{Float64})
Closest candidates are:
  LazyArray(::Base.Broadcast.Broadcasted) at ~/.julia/packages/LazyArrays/NYra8/src/lazybroadcasting.jl:35
  LazyArray(::Applied) at ~/.julia/packages/LazyArrays/NYra8/src/lazyapplying.jl:193
Stacktrace:
 [1] macro expansion
   @ ~/.julia/packages/Zygote/AS0Go/src/compiler/interface2.jl:0 [inlined]
 [2] _pullback(ctx::Zygote.Context{false}, f::Type{LazyArray}, args::Vector{Float64})
   @ Zygote ~/.julia/packages/Zygote/AS0Go/src/compiler/interface2.jl:9
 [3] _pullback
   @ ./REPL[53]:1 [inlined]
 [4] _pullback(ctx::Zygote.Context{false}, f::typeof(g), args::Vector{Float64})
   @ Zygote ~/.julia/packages/Zygote/AS0Go/src/compiler/interface2.jl:0
 [5] pullback(f::Function, cx::Zygote.Context{false}, args::Vector{Float64})
   @ Zygote ~/.julia/packages/Zygote/AS0Go/src/compiler/interface.jl:44
 [6] pullback
   @ ~/.julia/packages/Zygote/AS0Go/src/compiler/interface.jl:42 [inlined]
 [7] gradient(f::Function, args::Vector{Float64})
   @ Zygote ~/.julia/packages/Zygote/AS0Go/src/compiler/interface.jl:96
 [8] top-level scope
   @ REPL[54]:1

The first error can be "fixed" (I'm not entirely certain if this is the right way to go about it) by defining a chain rule:

julia> using ChainRulesCore

julia> function ChainRulesCore.rrule(config::RuleConfig{>:HasReverseMode}, ::Type{LazyArrays.BroadcastArray}, f, args...)
           return ChainRulesCore.rrule_via_ad(config, Broadcast.broadcasted, f, args...)
       end

julia> Zygote.refresh()

julia> Zygote.gradient(f, randn(10))
([0.24117702568683322, 2.478340448616497, 2.433266795642693, 1.6163793920298133, 1.8859252985478665, 3.9539878829654223, 1.2578105524502685, 0.48545348574922, 0.8710494256114425, 3.0853524634917076],)

Maybe the rest can be addressed this way too.

Are rules from CRC something that would be welcomed?

@dlfivefifty
Member

Hmm.... that's a good question.... I'm usually hesitant to add "*Core.jl" dependencies because a lot of them are of questionable usage but ChainRulesCore.jl might be an exception.

One alternative solution is to make a glue package a la FastTransformsForwardDiff. (I'm wondering whether that should have been FastTransformsChainRulesCore.jl...)

@torfjelde
Author

Either alternative is okay with me:)

@torfjelde
Author

You just say which alternative you prefer, and I can try to contribute towards it.

@dlfivefifty
Member

Let's put it in a separate package for now so we can work out the kinks. We can always merge it back here (in the event there's a good reason to have it).

@devmotion

It seems this is a good use case for weak deps. Some packages have already started moving their ChainRules definitions to weak deps. The definitions would be loaded only on Julia >= 1.9 (if you don't want to use Requires on older Julia versions), but I think it would be the better long-term solution.

@torfjelde
Copy link
Author

It would suck if we had to wait until Julia 1.9 before we could make use of this, though 😕

@devmotion

I assume it already works with the beta version, so I think you can already use it without compiling julia.

@dlfivefifty
Member

Can we do a separate package that works now, and becomes a weak dependency in Julia v1.9?

@devmotion

If a weak dependency is loaded, an extension (usually a single file) in the ext subfolder is loaded (and precompiled, in contrast to the Requires hacks!). AFAIK there are no separate packages involved or loaded in the extension apart from the weak dependency and the package + hard dependencies, and making the glue package a hard dependency would defeat its purpose. An example is shown in this PR: JuliaMath/ChangesOfVariables.jl#12
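
For concreteness, here is a rough sketch of what the weak-dependency setup described above might look like in the package's Project.toml. The extension name `LazyArraysChainRulesCoreExt` is a hypothetical name chosen for illustration, not something decided in this thread; the UUID is the one ChainRulesCore.jl is registered under.

```toml
# Fragment of Project.toml (Julia >= 1.9); the rest of the file is unchanged.

# ChainRulesCore is only a weak dependency: it is not installed or loaded
# unless the user's environment already has it.
[weakdeps]
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"

# Hypothetical extension name. Julia would look for the module in
# ext/LazyArraysChainRulesCoreExt.jl and load it once both the parent
# package and ChainRulesCore are loaded.
[extensions]
LazyArraysChainRulesCoreExt = "ChainRulesCore"
```

Under this layout, an rrule such as the one in the first post would live in the extension file instead of in the main package source, so users who never load ChainRulesCore would not pay for it.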

@dlfivefifty
Member

I see. I think a weak dependency here would be fine. I would suggest forgetting the separate project and just requiring v1.9.

@oschulz

oschulz commented Jan 26, 2023

We use weak deps for ChangesOfVariables.jl now, and it works like a charm on Julia v1.9:

julia> @time_imports import ChangesOfVariables
      0.6 ms  ChangesOfVariables

julia> @time_imports import ChainRulesCore
      0.1 ms  Compat
     58.9 ms  ChainRulesCore
      0.4 ms  ChangesOfVariables → ChainRulesCoreExt
