
Omnibus PR, including switch to explicit style differentiation #251

Merged
ablaom merged 18 commits into dev from refactor-regularization on Jun 10, 2024

Conversation

@ablaom (Collaborator) commented Apr 30, 2024

This PR combines a number of changes, which for technical reasons could not be easily split up. The most important change, anticipating a Flux 0.15 breakage, is the switch to explicit differentiation; so this PR replaces #230. Shout out to @pat-alt for outlining a solution there.

Closes #221.

To do:

  • Replace implicit-style parameter updates with explicit-style parameter updates, in
    line with planned Zygote/Flux deprecations (see the sketch after this list).

  • Refactor code to use optimisers from Optimisers.jl, with the setup/update pattern in
    place of the update! pattern. Also, rename the private methods train! -> train_epoch
    and fit! -> train to reflect the new non-mutating behaviour. This possibly breaks
    some "custom" models that have chosen to overload these technically private methods.

  • (RNG changes.) Change the default value of the model field rng from
    Random.GLOBAL_RNG to Random.default_rng(). Change the seeded RNG, obtained by
    specifying an integer value for rng, from MersenneTwister to Xoshiro.

  • Update the Short builder so that the rng argument of build(::Short, rng, ...)
    is passed on to the Dropout layer, as these layers now support this on a GPU, at
    least for rng=Random.default_rng().

  • Change the implementation of L1/L2 regularization from explicit loss penalization to
    weight/sign decay (internally chained with the user-specified
    optimiser). Breaking: The losses reported in the history will no longer be
    penalized, because the penalty is not explicitly computed.

  • Update documentation to reflect the use of Optimisers.jl optimisers instead of Flux.jl
    native optimisers, and the changes to the rng defaults. Waiting on #252
    (🚀 Instate documentation for MLJFlux).
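
For readers less familiar with the two styles, here is a minimal, self-contained sketch of
the explicit style this PR adopts, using the Optimisers.jl setup/update pattern. The data,
model, loss, and the train_explicit name below are illustrative placeholders, not MLJFlux
internals.

using Flux, Optimisers

# Explicit style: differentiate with respect to the model itself and thread the
# optimiser state through setup/update, instead of Flux.params and update!.
function train_explicit(chain, X, y; epochs = 10, eta = 0.01)
    state = Optimisers.setup(Optimisers.Adam(eta), chain)
    for epoch in 1:epochs, i in eachindex(X)
        grads = Flux.gradient(m -> Flux.Losses.mse(m(X[i]), y[i]), chain)
        state, chain = Optimisers.update(state, chain, grads[1])
    end
    return chain
end

# Toy usage (placeholders, not MLJFlux code): four batches of 16 observations.
X = [rand(Float32, 3, 16) for _ in 1:4]
y = [rand(Float32, 1, 16) for _ in 1:4]
chain = train_explicit(Chain(Dense(3 => 8, relu), Dense(8 => 1)), X, y)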

@ablaom ablaom marked this pull request as draft April 30, 2024 02:12
@ablaom ablaom mentioned this pull request May 1, 2024
@ablaom ablaom marked this pull request as ready for review May 3, 2024 03:59
@ablaom (Collaborator Author) commented May 3, 2024

This PR is above average in complexity. That means a review is particularly important, but it's also going to be more work than usual. @pat-alt Do you have any interest and time to review over the next 3 weeks, say?

My apologies in advance for slightly messy commits. I temporarily lost access to a GPU for local testing and was shooting blind for a while.

for i in 1:n_batches
    batch_loss, gs = Flux.withgradient(chain) do m
        yhat = m(X[i])
        loss(yhat, y[i]) + sum(penalty, Optimisers.trainables(m))/n_batches
ablaom (Collaborator Author):
@pat-alt Just FYI: Optimisers.jl added trainables, which offers this solution to refactor along your original lines. It superficially resembles Flux.params, but doesn't seem to suffer from the same problems we were seeing (if you can remember back that far 😉).

pat-alt (Collaborator):

Thanks!

So from a user perspective, I'm currently unclear about a few things (perhaps I've just not looked carefully enough!):

  1. Is the regularized_optimiser method user-facing? Where/when is it used here in this specific unit test?
  2. Would it make sense to ship the Penalizer?

@ablaom (Collaborator Author) commented May 22, 2024

@pat-alt Sorry to re-ping, but I'm not sure who else to ask for a review here.
@tiemvanderdeure Would you consider reviewing?

If possible, I'm hoping for a merge in the next 3 weeks. Even a cursory look would be much appreciated!

@pat-alt (Collaborator) commented May 23, 2024

Hi! I'll try to have a look as soon as I can (probably on the weekend or next week).

@tiemvanderdeure (Contributor):

If it is helpful, I can also give reviewing this PR a go. I probably won't have time in the next few days, but early next week should be feasible.

@ablaom (Collaborator Author) commented May 23, 2024

Let's see if @pat-alt is able to find some time.

@pat-alt pat-alt self-requested a review May 29, 2024 09:08
@pat-alt (Collaborator) left a review comment:

I mostly have a few comments about (1) deprecation and (2) clarifying how regularization is actually implemented on the user end.

Otherwise, this looks good to me!

@@ -14,13 +14,12 @@
abstract type Builder <: MLJModelInterface.MLJType end

"""
Linear(; σ=Flux.relu, rng=Random.GLOBAL_RNG)
pat-alt (Collaborator):

Is it worth deprecating this?

end


"""
fit!(model::MLJFlux.MLJFluxModel, penalty, chain, optimiser, epochs, verbosity, X, y)
pat-alt (Collaborator):

I think this, too, might be worth deprecating. If I understand this correctly, existing extensions that overload MLJFlux.fit! like here won't work anymore? As in, they should now be overloading the train function?

pat-alt (Collaborator):

So just to double-check @ablaom, fit! will become train and train! will become train_epoch?

@@ -43,7 +49,26 @@ end
const ERR_BUILDER =
"Builder does not appear to build an architecture compatible with supplied data. "

- true_rng(model) = model.rng isa Integer ? MersenneTwister(model.rng) : model.rng
+ true_rng(model) = model.rng isa Integer ? Random.Xoshiro(model.rng) : model.rng
pat-alt (Collaborator):

Just out of curiosity, why this change?
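
(Editorial aside, illustrative only: with this change an integer rng is wrapped in Xoshiro, the default RNG family in Julia since 1.7, rather than MersenneTwister; seeded runs remain reproducible.)

using Random

# Illustrative: what true_rng now returns for an integer rng field such as model.rng = 123.
rng = Xoshiro(123)
@assert rand(Xoshiro(123)) == rand(Xoshiro(123))   # seeded Xoshiro streams are reproducible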

# ensure penalization over an epoch does not scale with the choice of batch size; see
# https://github.com/FluxML/MLJFlux.jl/issues/213.

function regularized_optimiser(model, nbatches)
pat-alt (Collaborator):

So if I read this correctly, it is not possible to specify either L1 or L2? The regularized optimiser will always apply both?
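
(Editorial aside: below is a minimal illustrative sketch of the weight/sign-decay approach described in the PR summary, chaining decay rules ahead of the user's optimiser via Optimisers.jl. It is not necessarily the exact regularized_optimiser implementation; the lambda/alpha split is an assumption in the style of elastic net.)

using Optimisers

# L2 regularization as WeightDecay and L1 as SignDecay, chained ahead of the
# user-specified optimiser. A zero coefficient makes the corresponding decay
# rule a no-op, so having both rules present does not force both penalties.
lambda, alpha = 0.01, 0.4              # total strength and L1/L2 mix (assumed hyperparameters)
opt = OptimiserChain(
    SignDecay(lambda * alpha),         # L1 part
    WeightDecay(lambda * (1 - alpha)), # L2 part
    Adam(0.001),                       # user's optimiser
)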

@ablaom (Collaborator Author) commented May 29, 2024

Thanks @pat-alt for your review. Much appreciated. I've made a few tweaks in response to the comments.

@pat-alt (Collaborator) commented Jun 3, 2024

Thanks @ablaom, will have another look today.

@Rockdeldiablo for reference. See in particular the redefinition of fit! and my comments above.

@ablaom (Collaborator Author) commented Jun 10, 2024

@pat-alt How are we doing? Happy with the changes?

@pat-alt (Collaborator) commented Jun 10, 2024

> @pat-alt How are we doing? Happy with the changes?

Sorry, yes, I missed the thumbs up :) Thanks for clarifying!

@pat-alt (Collaborator) commented Jun 10, 2024

1bd58dd adds deprecations for the fit! and train! methods. I think this is useful for developers who have used the old API in their own packages to extend MLJFlux, as we have done here, for example. I'm not sure who else has done something like this (and I'm also unsure whether this is the intended way to extend MLJFlux), but in any case adding these deprecations should help.
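
(Editorial aside: the actual deprecations live in commit 1bd58dd; purely as a generic illustration, a soft deprecation for a renamed function often looks like the following, with old_name and new_name hypothetical.)

# Generic pattern only; the real MLJFlux deprecations may differ in form and signature.
function old_name(args...; kwargs...)
    Base.depwarn("`old_name` is deprecated; use `new_name` instead.", :old_name)
    return new_name(args...; kwargs...)
end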

@ablaom ablaom merged commit f38e0cf into dev Jun 10, 2024
6 checks passed
@ablaom ablaom deleted the refactor-regularization branch June 10, 2024 10:26
@ablaom (Collaborator Author) commented Jun 10, 2024

Thanks @pat-alt for your review. 🙏🏾
