Omnibus PR, including switch to explicit style differentiation #251
Conversation
This PR is above average in complexity. This means a review is particularly important, but it's also going to be more work than usual. @pat-alt Do you have any interest and time to review over the next 3 weeks, say? My apologies in advance for the slightly messy commits. I temporarily lost access to a GPU for local testing and was shooting blind for a while.
for i in 1:n_batches
    batch_loss, gs = Flux.withgradient(chain) do m
        yhat = m(X[i])
        loss(yhat, y[i]) + sum(penalty, Optimisers.trainables(m))/n_batches
@pat-alt Just FYI: Optimisers.jl added `trainables`, which offers this solution to refactor along your original lines. It superficially resembles `Flux.params`, but doesn't seem to suffer the same problems we were seeing (if you can remember back that far 😉).
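For context, a minimal sketch (not MLJFlux code) of how `Optimisers.trainables` can be used for a penalty term like the one above, assuming a toy chain and an L2-style `penalty` function:

```julia
using Flux, Optimisers

chain = Chain(Dense(4 => 8, relu), Dense(8 => 1))
penalty(w) = sum(abs2, w)   # L2-style penalty on a single parameter array

# trainables(chain) returns the trainable parameter arrays, so the total
# penalty is just a sum of `penalty` over them:
total_penalty = sum(penalty, Optimisers.trainables(chain))
```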
Thanks!

So from a user perspective, I'm currently unclear about a few things (perhaps I've just not looked carefully enough!):

- Is the `regularized_optimiser` method user-facing? Where/when is it used here in this specific unit test?
- Would it make sense to ship the `Penalizer`?
@pat-alt Sorry to re-ping, but I'm not sure who else to ask for a review here. If possible, I'm hoping for a merge in the next 3 weeks. Even a cursory look would be much appreciated!
Hi! I'll try to have a look as soon as I can (probably on the weekend or next week).
If it is helpful, I can also give reviewing this PR a go. I probably won't have time in the next few days, but early next week should be feasible.
Let's see if @pat-alt is able to find some time.
I mostly have a few comments about 1) deprecation and 2) clarification of how regularization is actually implemented on the user end.
Otherwise, this looks good to me!
@@ -14,13 +14,12 @@
abstract type Builder <: MLJModelInterface.MLJType end

"""
    Linear(; σ=Flux.relu, rng=Random.GLOBAL_RNG)
Is it worth deprecating this?
end

"""
    fit!(model::MLJFlux.MLJFluxModel, penalty, chain, optimiser, epochs, verbosity, X, y)
I think this, too, might be worth deprecating. If I understand this correctly, existing extensions that overload `MLJFlux.fit!` like here won't work anymore? As in, they should now be overloading the `train` function?
So just to double-check @ablaom: `fit!` will become `train`, and `train!` will become `train_epoch`?
@@ -43,7 +49,26 @@ end
const ERR_BUILDER =
    "Builder does not appear to build an architecture compatible with supplied data. "

-true_rng(model) = model.rng isa Integer ? MersenneTwister(model.rng) : model.rng
+true_rng(model) = model.rng isa Integer ? Random.Xoshiro(model.rng) : model.rng
Just out of curiosity, why this change?
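As an aside (not part of the PR itself), the visible consequence of this change is that an integer seed now drives a `Xoshiro` stream rather than a `MersenneTwister` one; both remain reproducible, but the streams differ, so seeded results will change:

```julia
using Random

old = MersenneTwister(123)     # previous seeded generator
new = Random.Xoshiro(123)      # new seeded generator
rand(old, 3)                   # reproducible, but...
rand(new, 3)                   # ...different values for the same seed
```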
# ensure penalization over an epoch does not scale with the choice of batch size; see
# https://github.com/FluxML/MLJFlux.jl/issues/213.

function regularized_optimiser(model, nbatches)
So if I read this correctly, it is not possible to specify either L1 or L2? The regularized optimiser will always apply both?
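For reference, a hedged sketch of the decay-based approach under discussion, chaining L1 (sign decay) and L2 (weight decay) ahead of the user's optimiser; the names `lambda` and `alpha` are illustrative and not necessarily how MLJFlux exposes these hyperparameters:

```julia
using Optimisers

lambda, alpha = 0.01, 0.4   # overall strength and L1/L2 mix (illustrative)

# Setting alpha = 1 or alpha = 0 would reduce this to pure L1 or pure L2 decay.
opt = OptimiserChain(
    SignDecay(lambda * alpha),           # L1-like contribution
    WeightDecay(lambda * (1 - alpha)),   # L2-like contribution
    Adam(),                              # user-specified optimiser
)
```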
Thanks @pat-alt for your review. Much appreciated. I've made a few tweaks in response to the comments.
Thanks @ablaom, will have another look today. @Rockdeldiablo for reference. See in particular the redefinition of
@pat-alt How are we doing? Happy with the changes?
Sorry, yes, I missed the thumbs up :) Thanks for clarifying!
1bd58dd adds deprecations for the
Thanks @pat-alt for your review. 🙏🏾
This PR combines a number of changes, which for technical reasons could not be easily split up. The most important change, anticipating a Flux 0.15 breakage, is the switch to explicit differentiation; so this PR replaces #230. Shout out to @pat-alt for outlining a solution there.
Closes #221.
To do:

- Replace implicit style parameter updates with explicit style parameter updates, in line with planned Zygote/Flux deprecations.
- Refactor code to use optimisers from Optimisers.jl with the `setup`/`update` pattern in place of the `update!` pattern (see the sketch after this list). Also, rename the private methods `train!` -> `train_epoch` and `fit!` -> `train` to reflect the new non-mutating behaviour. This possibly breaks some "custom" models that have chosen to overload these technically private methods.
- (RNG changes.) Change the default value of the model field `rng` from `Random.GLOBAL_RNG` to `Random.default_rng()`. Change the seeded RNG, obtained by specifying an integer value for `rng`, from `MersenneTwister` to `Xoshiro`.
- Update the `Short` builder so that the `rng` argument of `build(::Short, rng, ...)` is passed on to the `Dropout` layer, as these layers now support this on a GPU, at least for `rng=Random.default_rng()`.
- Change the implementation of L1/L2 regularization from explicit loss penalization to weight/sign decay (internally chained with the user-specified optimiser). Breaking: the losses reported in the history will no longer be penalized, because the penalty is not explicitly computed.
- Update documentation to reflect the use of Optimisers.jl optimisers instead of Flux.jl native optimisers, and the changes to the `rng` defaults. Waiting on 🚀 Instate documentation for MLJFlux #252.
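As referenced in the to-do list above, here is a minimal sketch of the explicit-style `setup`/`update` pattern this PR moves to. It is illustrative only, not MLJFlux's actual internals; the `train_explicit` helper, model, and data are made up for the example:

```julia
using Flux, Optimisers

# Explicit style: gradients are taken with respect to the model itself, and
# parameters are updated with Optimisers.setup/update rather than the old
# implicit Flux.params/update! pattern.
function train_explicit(chain, x, y; epochs = 10)
    opt_state = Optimisers.setup(Optimisers.Adam(0.01), chain)
    for _ in 1:epochs
        _, grads = Flux.withgradient(m -> Flux.mse(m(x), y), chain)
        opt_state, chain = Optimisers.update(opt_state, chain, grads[1])
    end
    return chain
end

chain = Chain(Dense(3 => 8, relu), Dense(8 => 1))
x, y = rand(Float32, 3, 32), rand(Float32, 1, 32)
chain = train_explicit(chain, x, y)
```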