Add more Duplicated methods for Enzyme.jl support #2471

Open · wants to merge 35 commits into base: master

Conversation

@mcabbott (Member) commented Jul 25, 2024

This adds a method like gradient(f, ::Duplicated) which, like train!(loss, model::Duplicated, data, opt) from #2446, uses the Duplicated type to signal that you want to use Enzyme rather than Zygote. It returns the gradient (for compatibility?) and mutates the Duplicated object.
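
As a rough illustration of the intended call pattern (the model, data and loss below are made up, and the shadow is built by hand with Enzyme.make_zero):

using Flux, Enzyme

model = Dense(2 => 1)                               # any model with mutable parameter arrays
dup   = Duplicated(model, Enzyme.make_zero(model))  # value plus a shadow that will hold the gradient

x, y = rand(Float32, 2, 8), rand(Float32, 1, 8)
loss(m) = Flux.mse(m(x), y)

g = Flux.gradient(loss, dup)   # uses Enzyme, writes into dup.dval, and also returns the gradient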

  • To avoid piracy, this creates a new function Flux.gradient which by default calls Zygote.gradient. Unfortunately that means every using Flux, Zygote now produces ambiguities... so probably it should not be exported, which means 0.15. Then again, such ambiguities give a clear message, so maybe that's OK? Maybe clearer than stopping the export.

  • There's also withgradient; at first it didn't allow you to return a tuple the way Zygote does, but now it does.

  • There's also a method of update!, which either needs to move to Optimisers.jl or else Flux needs to own that function too. This has moved to FluxML/Optimisers.jl#192 (Add Duplicated methods).

  • Finally, @layer Chain defines a 1-argument Duplicated(c::Chain) method, so that you don't need to construct the dual by hand.
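
Roughly, that convenience constructor amounts to the following sketch (not the exact macro-generated code; it assumes Enzyme is loaded):

using Flux, Enzyme

# sketch: the one-argument constructor pairs the model with a zeroed shadow
Enzyme.Duplicated(c::Chain) = Enzyme.Duplicated(c, Enzyme.make_zero(c))

dup = Duplicated(Chain(Dense(2 => 3, relu), Dense(3 => 1)))   # value + zero shadow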

WIP, RFC?

Needs tests, and docs.

PR Checklist

  • Tests are added
  • Entry in NEWS.md
  • Documentation, if applicable

@ToucheSir (Member):

The docs failure here looks real, but I'm not sure why.

@mcabbott (Member Author) commented Nov 6, 2024

Fixed the docs. I think we need to own update! too, for this not to be piracy.

But while I fix that, any objections to the interface? We seem to be merging things in a hurry now...

@mcabbott (Member Author) commented Nov 6, 2024

CUDA test failure is like this (and one more), why now?

Dropout Layer GPU grad test: Test Failed at /var/lib/buildkite-agent/builds/gpuci-15/julialang/flux-dot-jl/test/test_utils.jl:77
  Expression: ≈(y_gpu, y, rtol = rtol, atol = atol)
  Evaluated: 0.50719213f0 ≈ 0.5073142f0 (rtol=0.0001, atol=0.0001)

src/layers/macro.jl: outdated review thread (resolved)
@ToucheSir (Member):

The interface looks reasonable to me and Flux owning gradient could make sense for moving off Zygote, but I don't love the careful internal plumbing required to make this work. I wonder if we can work around concerns about type piracy by moving the update! and trainable overloads to an Optimisers.jl extension.

@mcabbott (Member Author) commented Nov 8, 2024

Can you clarify what you mean by "the careful internal plumbing"? If you mean how gradient works, it's trying to give friendly errors if you have somehow loaded only EnzymeCore, and also not to require Const; we could simplify by requiring either all Const/Duplicated or none. But this doesn't seem so horrendous.

update! has a lot of methods because we keep old Flux.Optimise around (and aim for friendly errors). But I don't mind moving the new ones to Optimisers.jl. Flux has its own setup but using the wrong version matters less there.

Edit: now requires FluxML/Optimisers.jl#192
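
For context, the intended end-to-end training step would then look something like this sketch (the two-argument update! on a Duplicated model is the method proposed in FluxML/Optimisers.jl#192, so its exact form may differ):

using Flux, Enzyme

m   = Chain(Dense(2 => 3, relu), Dense(3 => 1))
dup = Duplicated(m, Enzyme.make_zero(m))
opt_state = Flux.setup(Adam(0.01), dup)        # per this PR, setup sees only the model, not the shadow

x, y = rand(Float32, 2, 16), rand(Float32, 1, 16)

Flux.gradient(m -> Flux.mse(m(x), y), dup)     # Enzyme path, writes the gradient into dup.dval
Flux.update!(opt_state, dup)                   # assumed Duplicated method from Optimisers.jl#192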

codecov bot commented Nov 8, 2024

Codecov Report

Attention: Patch coverage is 78.49462% with 20 lines in your changes missing coverage. Please review.

Project coverage is 62.31%. Comparing base (e2b3f06) to head (a1380a0).

Files with missing lines               Patch %   Lines
src/deprecations.jl                      0.00%   6 Missing ⚠️
src/gradient.jl                         84.84%   5 Missing ⚠️
src/layers/macro.jl                     71.42%   4 Missing ⚠️
src/train.jl                            40.00%   3 Missing ⚠️
ext/FluxEnzymeExt/FluxEnzymeExt.jl      93.33%   2 Missing ⚠️
Additional details and impacted files
@@             Coverage Diff             @@
##           master    #2471       +/-   ##
===========================================
+ Coverage   33.54%   62.31%   +28.77%     
===========================================
  Files          31       33        +2     
  Lines        1881     1993      +112     
===========================================
+ Hits          631     1242      +611     
+ Misses       1250      751      -499     


@ToucheSir (Member):

The gradient code looked pretty straightforward, but there was also update!, train! and all of _macro_enzyme. I think the gradient changes are fine (though I'm surprised the code is type stable!), but would also be ok with limiting to Const + Duplicated if that makes things easier. It's easier to relax parts of the API than to tighten them up later.

The ideal path right now would be landing FluxML/Optimisers.jl#192 and then removing _macro_enzyme from here. I think that should be straightforward?

@mcabbott (Member Author) commented Nov 8, 2024

Re gradient, one quirk is that Enzyme has the rule that anything not Const is active:

julia> Enzyme.gradient(Reverse, dot, [1, 2.], [3, 4.])
([3.0, 4.0], [1.0, 2.0])

julia> Enzyme.gradient(Reverse, dot, [1, 2.], Const([3, 4.]))
([3.0, 4.0], nothing)

while this PR's rule is that, once one thing is Duplicated, everything else is constant:

julia> Flux.gradient(dot, [1, 2.], [3, 4.])  # Zygote
([3.0, 4.0], [1.0, 2.0])

julia> Flux.gradient(dot, [1, 2.], Duplicated([3, 4.], [NaN, NaN]))  # implicit Const
(nothing, [1.0, 2.0])

julia> Flux.gradient(dot, [1, 2.], Const([3, 4.]))
ERROR: ArgumentError: The method `gradient(f, xs...)` using Enzyme.jl requires at least one `Duplicated` argument, not just `Const`.

IDK if that's too weird.

Edit, one more quirk... now fixed:

julia> Flux.gradient((x,y) -> sum((x .* y).^2), [1,2,3.], 4.0)  # Zygote
([32.0, 64.0, 96.0], 112.0)

julia> Flux.gradient((x,y) -> sum((x .* y).^2), Duplicated([1,2,3.], zeros(3)), 4.0)  # implicit Const
([32.0, 64.0, 96.0], nothing)

julia> Flux.gradient((x,y) -> sum((x .* y).^2), Duplicated([1,2,3.], zeros(3)), Active(4.0))
ERROR: ArgumentError: The method `gradient(f, xs...)` using Enzyme.jl does not support `Active`, only `Duplicated` and `Const`.

julia> Flux.gradient((x,y) -> sum((x .* y).^2), Duplicated([1,2,3.], zeros(3)), Duplicated(4.0, 0.0))  # now an error, 2636454
ERROR: `Flux.gradient(f, Duplicated(x), ...)` expects `x` to contain mutable parameter arrays.

@mcabbott (Member Author):

One question this PR opens is how other ADs might slot in. Mooncake.jl similarly prefers pre-allocated space for gradients, but uses different types. The obvious interface would be to have some container like so:

struct MoonPair{X,DX}; x::X; dx::DX; end
MoonPair(x) = MoonPair(x, Mooncake.zero_codual(x))

Flux could own this. That would be easy, in terms of e.g. defining show nicely.

Mooncake.jl could also own it, but then @layer cannot rely on it being defined. An extension would have to dispatch on something like Flux.gradient(f, args::Union{MoonPair, Const}...) and maybe that gets messy if the model is the 3rd argument.
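
To make that concrete, the extension-stub pattern this PR already uses for Enzyme could simply be repeated. Everything below is hypothetical, mirroring _enzyme_gradient rather than describing any existing Mooncake API:

# hypothetical sketch of a Flux-owned wrapper plus a stub filled in by an extension
struct MoonPair{X,DX}
    x::X
    dx::DX
end

Flux.gradient(f, args::Union{MoonPair, Const}...) = _mooncake_gradient(f, args...)

# the extension would overload this; the stub only gives a friendly error
_mooncake_gradient(f, args...) =
    error("Mooncake.jl must be loaded to differentiate MoonPair arguments")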


forward, reverse = autodiff_thunk(ReverseSplitWithPrimal, Const{typeof(f)}, Active, map(typeof, args)...)
tape, result, shadow_result = forward(Const(f), args...)
reverse(Const(f), args..., _sensitivity(result), tape)
Contributor:

why not just call autodiff here?

Member Author:

This was adapted from the ReverseSplitWithPrimal doc example. My understanding was that it needs that in order to construct _sensitivity(result) after seeing what f returns.

Contributor:

What is the expected return type of f? If it's a float, it should be fine to just use autodiff directly.

If it's not a float (as it looks like below), it might be more efficient to do something like the following, since split mode will introduce overhead.

@inline asfloat(x) = error("""`Flux.withgradient(f, xs...)` expects that `y = f(xs...)` is a real number,
    or else a Tuple or NamedTuple whose first element is a real number.""")

@inline asfloat(x::Real) = x
@inline asfloat(x::Tuple) = asfloat(x[1])
@inline asfloat(x::NamedTuple) = asfloat(x[1])

function return_asfloat(f, args...)
    return asfloat(@inline f(args...))
end

autodiff(Reverse, Const(return_asfloat), Active, Const(f), ...)
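
Using those definitions, a call on an aux-returning loss would then look like this (a sketch with a made-up loss g in place of the elided arguments above):

g(x) = (sum(abs2, x), "aux output")   # returns (loss, aux...)
x, dx = [1.0, 2.0, 3.0], zeros(3)

autodiff(Reverse, Const(return_asfloat), Active, Const(g), Duplicated(x, dx))
dx                                    # [2.0, 4.0, 6.0]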

Member Author:

The point is to pass things other than the loss out of the function:

julia> withgradient([1,2,4]) do x
            z = 1 ./ x
            sum(z), string("aux output: z = ", z)
         end
(val = (1.75, "aux output: z = [1.0, 0.5, 0.25]"), grad = ([-1.0, -0.25, -0.0625],))

@mcabbott (Member Author) commented Nov 11, 2024:

It's possible that trying to infer the return type of f would pay off: when it is inferred to return a float, you could call ReverseWithPrimal instead. How big is the overhead of this ReverseSplitWithPrimal?

Edit, with examples from the other thread I can't measure it, <1% maybe?
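
For anyone who wants to measure, a rough way to compare the two paths on a toy loss (a sketch only; f, x and the sizes are arbitrary, and the thunk is built once outside the timing):

using Enzyme, BenchmarkTools

f(x) = sum(abs2, x)
x, dx = rand(100), zeros(100)

# one-shot reverse mode with primal
@btime Enzyme.autodiff(ReverseWithPrimal, $f, Active, Duplicated($x, $dx))

# split reverse mode, roughly the shape used above
fwd, rev = autodiff_thunk(ReverseSplitWithPrimal, Const{typeof(f)}, Active, Duplicated{typeof(x)})
@btime begin
    tape, result, shadow = $fwd(Const($f), Duplicated($x, $dx))
    $rev(Const($f), Duplicated($x, $dx), 1.0, tape)
end
# note: dx accumulates across benchmark runs; this only compares timings, not values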

Contributor:

Ah, well in that case maybe something like the following (assuming you know the return type)

function byref(out, f, args...)
    res = f(args...)
    out[] = asfloat(res)
    return res
end

dout = DuplicatedNoNeed(Ref(0.0), Ref(1.0))
autodiff(Reverse, byref, Const, dout, Const(f), ...)

Member Author:

This looks possible. But where is it better? Is the overhead of ReverseSplitWithPrimal in running it, or in generating code or something (which I didn't try to time)?

Contributor:

Running it, though how much it matters depends on the code being differentiated.

Member Author:

FWIW, I tried this Ref approach in ecef1f0, but couldn't get all cases to work.

Of the cases in the docstring text, the complicated ones which return (loss, aux...) worked, but the simple ones returning just the loss did not. The @show val always worked, but result = autodiff did not always.

@wsmoses (Contributor) commented Nov 11, 2024

Longer term it might be worth making sure that the design is flexible enough for Reactant integration as well (x/ref https://lux.csail.mit.edu/stable/manual/compiling_lux_models). Lux has already shown relatively big speedups at least on CPU

@avik-pal (Member):

Longer term it might be worth making sure that the design is flexible enough for Reactant integration as well (x/ref lux.csail.mit.edu/stable/manual/compiling_lux_models). Lux has already shown relatively big speedups at least on CPU

MLDataDevices has get_device_type, which lets you get the device; see https://github.com/LuxDL/Lux.jl/blob/0be75045c37a51fc6369a28c9e8e893c1044089d/src/helpers/training.jl#L200-L206. If we get an AutoEnzyme + ReactantDevice, it switches to using Reactant. You also need to ensure that compilation is done only once (https://github.com/LuxDL/Lux.jl/blob/0be75045c37a51fc6369a28c9e8e893c1044089d/ext/LuxReactantExt/training.jl#L37-L57) and then reused (https://github.com/LuxDL/Lux.jl/blob/0be75045c37a51fc6369a28c9e8e893c1044089d/ext/LuxReactantExt/training.jl#L59-L70).

src/train.jl: outdated review thread (resolved)
Co-authored-by: Carlo Lucibello <[email protected]>
@CarloLucibello (Member):

sorry for this rebase mess

@mcabbott (Member Author):

What does this break?

@CarloLucibello (Member):

Whoever does using Zygote, Flux and calls gradient will find their code broken, right?

@mcabbott (Member Author):

Ah that is true. Broken with a helpful message but I agree.

@CarloLucibello (Member):

We should specialize Flux.state for Duplicated to return only the model's parameters.

@mcabbott (Member Author):

Certainly it should not save the gradient.

I guess there's an open question here: should gradient/setup/state/loadmodel! deal with an outer (; val = ..., dval = nothing) layer, or not? At the present state of the PR, gradient/setup choose not to, which ensures gradient(f, m) == gradient(f, Duplicated(m)). Maybe it is OK for state/loadmodel! to follow that?
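
Concretely, the invariant being kept is that both calls below return a plain gradient tuple of the same shape, with no extra (; val = ..., dval = ...) wrapper around it (a toy example, assuming both backends are loaded):

using Flux, Enzyme   # Zygote is loaded by Flux itself

m = Dense(2 => 1)
x = rand(Float32, 2, 4)
loss(m) = sum(abs2, m(x))

g1 = Flux.gradient(loss, m)                                   # Zygote path
g2 = Flux.gradient(loss, Duplicated(m, Enzyme.make_zero(m)))  # Enzyme path, same structure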

@CarloLucibello (Member):

Maybe it is OK for state/loadmodel! to follow that?

+1

Comment on lines 17 to 24
function Flux._enzyme_gradient(f, args::Union{Const, Duplicated}...; zero::Bool=true)
    for x in args
        zero && x isa Duplicated && make_zero!(x.dval)
        _check_mutable(x)
    end
    Enzyme.autodiff(Reverse, f, Active, args...)
    map(_grad_or_nothing, args)
end
Member Author:

Some failure cases:

julia> using Flux, Enzyme

# Zygote

julia> Flux.gradient(sum ∘ LayerNorm(3), zeros(3))
([0.0, 0.0, 0.0],)

julia> Flux.gradient(|>, zeros(3), sum ∘ LayerNorm(3))
([0.0, 0.0, 0.0], (outer = nothing, inner = (λ = nothing, diag = (scale = Float32[0.0, 0.0, 0.0], bias = Float32[1.0, 1.0, 1.0], σ = nothing), ϵ = 0.0, size = nothing, affine = nothing)))

# Enzyme

julia> Flux.gradient(sum ∘ LayerNorm(3), Duplicated(zeros(3), zeros(3)))
ERROR: Constant memory is stored (or returned) to a differentiable variable.
As a result, Enzyme cannot provably ensure correctness and throws this error.
This might be due to the use of a constant variable as temporary storage for active memory (https://enzyme.mit.edu/julia/stable/faq/#Runtime-Activity).
If Enzyme should be able to prove this use non-differentable, open an issue!
To work around this issue, either:
 a) rewrite this variable to not be conditionally active (fastest, but requires a code change), or
 b) set the Enzyme mode to turn on runtime activity (e.g. autodiff(set_runtime_activity(Reverse), ...) ). This will maintain correctness, but may slightly reduce performance.
Mismatched activity for:   store {} addrspace(10)* %.fca.0.0.0.extract, {} addrspace(10)* addrspace(10)* %.innerparm.sroa.0.0..sroa_cast, align 8, !dbg !15, !alias.scope !17, !noalias !21 const val:   %.fca.0.0.0.extract = extractvalue { { { {} addrspace(10)*, {} addrspace(10)* }, float, [1 x i64], i8 } } %0, 0, 0, 0, !dbg !8
 value=Unknown object of type Vector{Float32}
 llvalue=  %.fca.0.0.0.extract = extractvalue { { { {} addrspace(10)*, {} addrspace(10)* }, float, [1 x i64], i8 } } %0, 0, 0, 0, !dbg !8

Stacktrace:
 [1] ComposedFunction
   @ ./operators.jl:1041
 [2] ComposedFunction
   @ ./operators.jl:0

Stacktrace:
  [1] ComposedFunction
    @ ./operators.jl:1041 [inlined]
  [2] ComposedFunction
    @ ./operators.jl:0 [inlined]
  [3] augmented_julia_ComposedFunction_18168_inner_1wrap
    @ ./operators.jl:0
  [4] macro expansion
    @ ~/.julia/packages/Enzyme/RTS5U/src/compiler.jl:8398 [inlined]
  [5] enzyme_call
    @ ~/.julia/packages/Enzyme/RTS5U/src/compiler.jl:7950 [inlined]
  [6] AugmentedForwardThunk
    @ ~/.julia/packages/Enzyme/RTS5U/src/compiler.jl:7787 [inlined]
  [7] autodiff(rmode::ReverseMode{…}, f::Const{…}, ::Type{…}, args::Duplicated{…})
    @ Enzyme ~/.julia/packages/Enzyme/RTS5U/src/Enzyme.jl:384
  [8] _enzyme_gradient(f::Function, args::Duplicated{Vector{Float64}}; zero::Bool)
    @ FluxEnzymeExt ~/.julia/dev/Flux/ext/FluxEnzymeExt/FluxEnzymeExt.jl:22
  [9] _enzyme_gradient
    @ ~/.julia/dev/Flux/ext/FluxEnzymeExt/FluxEnzymeExt.jl:17 [inlined]
 [10] gradient(f::Function, args::Duplicated{Vector{Float64}})
    @ Flux ~/.julia/dev/Flux/src/gradient.jl:122

julia> Duplicated(c) = Duplicated(c, Enzyme.make_zero(c));

julia> Flux.gradient(|>, Duplicated(zeros(3)), Duplicated(sum ∘ LayerNorm(3)))
ERROR: setfield!: immutable struct of type ComposedFunction cannot be changed
Stacktrace:
 [1] make_zero!
   @ ~/.julia/packages/Enzyme/RTS5U/src/make_zero.jl:529 [inlined]
 [2] make_zero!(prev::ComposedFunction{typeof(sum), LayerNorm{typeof(identity), Flux.Scale{…}, Float32, 1}})
   @ Enzyme.Compiler ~/.julia/packages/Enzyme/RTS5U/src/make_zero.jl:504
 [3] _enzyme_gradient(::Function, ::Duplicated{Vector{Float64}}, ::Vararg{Union{Const, Duplicated}}; zero::Bool)
   @ FluxEnzymeExt ~/.julia/dev/Flux/ext/FluxEnzymeExt/FluxEnzymeExt.jl:19

julia> Flux.gradient(|>, Duplicated(zeros(3)), Const(sum ∘ LayerNorm(3)))
ERROR: Constant memory is stored (or returned) to a differentiable variable.
As a result, Enzyme cannot provably ensure correctness and throws this error.
This might be due to the use of a constant variable as temporary storage for active memory (https://enzyme.mit.edu/julia/stable/faq/#Runtime-Activity).
If Enzyme should be able to prove this use non-differentable, open an issue!
To work around this issue, either:
 a) rewrite this variable to not be conditionally active (fastest, but requires a code change), or
 b) set the Enzyme mode to turn on runtime activity (e.g. autodiff(set_runtime_activity(Reverse), ...) ). This will maintain correctness, but may slightly reduce performance.
Mismatched activity for:   store {} addrspace(10)* %.fca.0.0.0.extract, {} addrspace(10)* addrspace(10)* %.innerparm.sroa.0.0..sroa_cast, align 8, !dbg !17, !alias.scope !20, !noalias !24 const val:   %.fca.0.0.0.extract = extractvalue { { { {} addrspace(10)*, {} addrspace(10)* }, float, [1 x i64], i8 } } %1, 0, 0, 0, !dbg !8
 value=Unknown object of type Vector{Float32}
 llvalue=  %.fca.0.0.0.extract = extractvalue { { { {} addrspace(10)*, {} addrspace(10)* }, float, [1 x i64], i8 } } %1, 0, 0, 0, !dbg !8
 
 # to see code?
 
 julia> @less LayerNorm(3)(zeros(3))

julia> @less LayerNorm(3).diag(zeros(3))

Member Author:

Those are solved by 5c1650f

Remaining related bugs in withgradient:

julia> Flux.withgradient(|>, Duplicated(zeros(3)), Duplicated(sum ∘ LayerNorm(3)))
ERROR: AssertionError: Base.allocatedinline(actualRetType) returns false: actualRetType = Any, rettype = Active{Any}
Stacktrace:
  [1] create_abi_wrapper(enzymefn::LLVM.Function, TT::Type, rettype::Type, actualRetType::Type, Mode::Enzyme.API.CDerivativeMode, augmented::Ptr{…}, width::Int64, returnPrimal::Bool, shadow_init::Bool, world::UInt64, interp::Enzyme.Compiler.Interpreter.EnzymeInterpreter{…})
    @ Enzyme.Compiler ~/.julia/packages/Enzyme/RTS5U/src/compiler.jl:4433
  [2] enzyme!(job::GPUCompiler.CompilerJob{…}, mod::LLVM.Module, primalf::LLVM.Function, TT::Type, mode::Enzyme.API.CDerivativeMode, width::Int64, parallel::Bool, actualRetType::Type, wrap::Bool, modifiedBetween::Tuple{…}, returnPrimal::Bool, expectedTapeType::Type, loweredArgs::Set{…}, boxedArgs::Set{…})
    @ Enzyme.Compiler ~/.julia/packages/Enzyme/RTS5U/src/compiler.jl:4169
  [3] codegen(output::Symbol, job::GPUCompiler.CompilerJob{…}; libraries::Bool, deferred_codegen::Bool, optimize::Bool, toplevel::Bool, strip::Bool, validate::Bool, only_entry::Bool, parent_job::Nothing)
    @ Enzyme.Compiler ~/.julia/packages/Enzyme/RTS5U/src/compiler.jl:7338
  [4] codegen
    @ ~/.julia/packages/Enzyme/RTS5U/src/compiler.jl:6146 [inlined]
  [5] _thunk(job::GPUCompiler.CompilerJob{Enzyme.Compiler.EnzymeTarget, Enzyme.Compiler.EnzymeCompilerParams}, postopt::Bool)
    @ Enzyme.Compiler ~/.julia/packages/Enzyme/RTS5U/src/compiler.jl:8468
  [6] _thunk
    @ ~/.julia/packages/Enzyme/RTS5U/src/compiler.jl:8468 [inlined]
  [7] cached_compilation
    @ ~/.julia/packages/Enzyme/RTS5U/src/compiler.jl:8509 [inlined]
  [8] thunkbase(ctx::LLVM.Context, mi::Core.MethodInstance, ::Val{…}, ::Type{…}, ::Type{…}, tt::Type{…}, ::Val{…}, ::Val{…}, ::Val{…}, ::Val{…}, ::Val{…}, ::Type{…}, ::Val{…}, ::Val{…})
    @ Enzyme.Compiler ~/.julia/packages/Enzyme/RTS5U/src/compiler.jl:8641
  [9] #s2103#19072
    @ ~/.julia/packages/Enzyme/RTS5U/src/compiler.jl:8778 [inlined]
 [10] 
    @ Enzyme.Compiler ./none:0
 [11] (::Core.GeneratedFunctionStub)(::UInt64, ::LineNumberNode, ::Any, ::Vararg{Any})
    @ Core ./boot.jl:602
 [12] autodiff_thunk(::EnzymeCore.ReverseModeSplit{…}, ::Type{…}, ::Type{…}, ::Type{…}, ::Type{…})
    @ Enzyme ~/.julia/packages/Enzyme/RTS5U/src/Enzyme.jl:968
 [13] _enzyme_withgradient(::Function, ::Duplicated{Vector{Float64}}, ::Vararg{Union{Const, Duplicated}}; zero::Bool)
    @ FluxEnzymeExt ~/.julia/dev/Flux/ext/FluxEnzymeExt/FluxEnzymeExt.jl:72
 [14] _enzyme_withgradient
    @ ~/.julia/dev/Flux/ext/FluxEnzymeExt/FluxEnzymeExt.jl:177 [inlined]
 [15] withgradient(::Function, ::Duplicated{Vector{Float64}}, ::Duplicated{ComposedFunction{typeof(sum), LayerNorm{…}}})
    @ Flux ~/.julia/dev/Flux/src/gradient.jl:226
 [16] top-level scope
    @ REPL[112]:1
Some type information was truncated. Use `show(err)` to see complete types.

Contributor:

So I think this is a good demonstration of why a direct autodiff instead of a thunk (returning, say, the loss as a float and passing any extra info out via a Const(Ref()), or vice versa) would be useful.

@mcabbott (Member Author) commented Nov 27, 2024:

Ah maybe I got that working at last, 016cfe6:

julia> Flux.withgradient(|>, Duplicated(zeros(3)), Duplicated(sum ∘ LayerNorm(3)))
(val = 0.0, grad = ([0.0, 0.0, 0.0], (outer = nothing, inner = (λ = nothing, diag = (scale = Float32[0.0, 0.0, 0.0], bias = Float32[1.0, 1.0, 1.0], σ = nothing), ϵ = nothing, size = (nothing,), affine = nothing))))

That's the "or vice versa" path, I think -- write the loss into Ref(0.0) & return everything else.
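
Stripped of the Flux plumbing, that Ref pattern looks roughly like this standalone sketch (simplified to return nothing instead of the auxiliary outputs; it is not the code in FluxEnzymeExt):

using Enzyme

# write the scalar loss into `out`; returning nothing lets the return activity be Const
function loss_into_ref!(out, f, args...)
    out[] = f(args...)
    return nothing
end

f(x) = sum(abs2, x)
x, dx = [1.0, 2.0, 3.0], zeros(3)
out = Ref(0.0)

Enzyme.autodiff(Reverse, Const(loss_into_ref!), Const,
                Duplicated(out, Ref(1.0)), Const(f), Duplicated(x, dx))

out[], dx     # loss 14.0 and gradient [2.0, 4.0, 6.0]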
